Plotting heatmap with gridlines in matplotlib misses gridlines

I am trying to plot a heatmap with gridlines. This is my code (adapted from this post):

# Plot a heatmap with gridlines

import numpy as np
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)

plt.tight_layout()
ax = plt.subplot(111)
ax.imshow(arr, cmap='viridis')

xr = ax.get_xlim()
yr = ax.get_ylim()
ax.set_xticks(np.arange(max(xr))-0.5, minor=True)
ax.set_yticks(np.arange(max(yr))-0.5, minor=True)
ax.grid(which='minor', snap=False, color='k', linestyle='-', linewidth=1)
ax.tick_params(which='major', bottom=False, left=False)
ax.tick_params(which='minor', bottom=False, left=False)
ax.set_xticklabels([])
ax.set_yticklabels([])
for spine in ax.spines.values():
    spine.set_visible(False)


plt.show()

I get the following plot (cropped to content):

[Figure: the resulting heatmap, with some vertical gridlines missing]

There are no vertical gridlines after the 3rd, the 8th, and the third-to-last columns. Since this script uses minor ticks to draw the gridlines, I also printed the x ticks and made the same plot without hiding the labels:

print('xticks:', np.arange(max(xr))-0.5)

xticks: [-0.5  0.5  1.5  2.5  3.5  4.5  5.5  6.5  7.5  8.5  9.5 10.5 11.5 12.5
 13.5 14.5 15.5 16.5 17.5 18.5]

[Figure: the same heatmap with tick labels shown]

Both the printed ticks and the labelled plot show that all the necessary ticks are there. What could be the issue here?
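
A likely explanation, which the thread itself does not confirm: since version 3.1, matplotlib trims minor ticks that land on a major tick location (controlled by the `Axis.remove_overlapping_locs` property, which defaults to True). `ax.grid(which='minor')` only draws lines for the minor ticks that survive this trimming, and the major grid is off, so any cell boundary that coincides with an automatically placed major tick (for example at 2.5 or 7.5) gets no line. Printing `np.arange(...)` shows the requested positions, not the ones matplotlib keeps. The diagnostic sketch below compares the two; the variable names `requested`, `kept` and `dropped` are illustrative only.

```python
import numpy as np
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)

fig, ax = plt.subplots()
ax.imshow(arr, cmap="viridis")

# Request a minor tick on every cell boundary along x.
requested = np.arange(arr.shape[1] + 1) - 0.5
ax.set_xticks(requested, minor=True)

# Draw once so the aspect ratio and the automatic major locator are settled.
fig.canvas.draw()

major = ax.xaxis.get_majorticklocs()  # positions chosen by the default major locator
kept = ax.xaxis.get_minorticklocs()   # minor positions left after overlap trimming
dropped = np.setdiff1d(requested, kept)

print("major ticks:  ", major)
print("dropped minor:", dropped)
```

If the hypothesis holds, every entry in `dropped` should also appear among the major tick positions.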

Solution:

I have no idea why this happens. Using the major grid instead fixes this issue:

import numpy as np
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)

plt.tight_layout()
ax = plt.subplot(111)
ax.imshow(arr, cmap='viridis')

# Put the cell boundaries on the *major* ticks instead of the minor ones
xr = ax.get_xlim()
yr = ax.get_ylim()
ax.set_xticks(np.arange(max(xr))-0.5, minor=False)
ax.set_yticks(np.arange(max(yr))-0.5, minor=False)

# Draw the grid from the major ticks
ax.grid(which='major', snap=False, color='k', linestyle='-', linewidth=1)

# Hide the tick marks (the tick labels remain visible)
ax.tick_params(which='major', bottom=False, left=False)
ax.tick_params(which='minor', bottom=False, left=False)

plt.show()
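
If you prefer to keep the minor-tick approach from the question, a minimal sketch, assuming the missing lines come from matplotlib trimming minor ticks that overlap major ones (the `Axis.remove_overlapping_locs` property, available since matplotlib 3.1), is to switch that trimming off on both axes. It also places minor ticks on all `n + 1` cell boundaries, so the outer edges still get a line when the spines are hidden (with `np.arange(max(xr)) - 0.5` the right and bottom edges have no tick).

```python
import numpy as np
import matplotlib.pyplot as plt

arr = np.random.randn(3, 20)

fig, ax = plt.subplots()
ax.imshow(arr, cmap="viridis")

# Minor ticks on every cell boundary, including both outer edges.
n_rows, n_cols = arr.shape
ax.set_xticks(np.arange(n_cols + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(n_rows + 1) - 0.5, minor=True)

# Keep minor ticks even where they coincide with a major tick location;
# by default matplotlib drops those, and their gridlines disappear with them.
ax.xaxis.remove_overlapping_locs = False
ax.yaxis.remove_overlapping_locs = False

# Grid from the minor ticks only; hide tick marks and labels for both tick sets.
ax.grid(which="minor", color="k", linestyle="-", linewidth=1)
ax.tick_params(which="both", bottom=False, left=False,
               labelbottom=False, labelleft=False)
for spine in ax.spines.values():
    spine.set_visible(False)

plt.show()
```

Compared with moving the grid to the major ticks, this leaves the default major locator untouched, which matters if you later want major tick labels at their usual positions.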