How to plot imshow legend based on colormap?

I am trying to plot a legend based on the colormap of my imshow plot.

I could do that by specifying the colors and labels by hand with matplotlib.patches (imported as mpatches) and matplotlib.colors, as in some examples I have seen, but I would like the legend colors to follow the cmap automatically. For example:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Data
df = pd.DataFrame({"a":[0.1, 0.3, 0.4, 0.2, 0.5, 0.6],
                   "z":[5, 10, 15, 20, 25, 30],
                   "class":[1, 2, 3, 1, 3, 3]})

# 'Class' 2D cluster 
cluster = np.repeat(np.expand_dims(df["class"].values, 1), 10, 1)

f, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,10))
ax[0].plot(df["a"], df["z"])
ax[1].imshow(cluster, interpolation='none', cmap="Paired",
             aspect='auto', vmin=1, vmax=3, extent=[0, 1, df["z"].min(), df["z"].max()])

ax[0].set_ylim(df["z"].min(), df["z"].max())

ax[1].legend()

Here I would like the legend to use the colors of the "Paired" cmap from my imshow plot, but this code only produces the warning "No handles with labels found to put in legend". Can anyone help?
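A quick way to see why ax[1].legend() comes up empty is to ask the Axes which handles it would collect. This is a minimal sketch, assuming matplotlib 3.x and the non-interactive Agg backend; the label="class" argument is added only for illustration, to show that labeling the image does not help, because legend() collects lines, patches, collections and similar artists but not images.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
# The image gets a label, but an AxesImage is not a legend candidate.
im = ax.imshow(np.array([[1], [2], [3]]), cmap="Paired", vmin=1, vmax=3,
               label="class")
handles, labels = ax.get_legend_handles_labels()
print(handles, labels)  # both lists are empty
```

Since nothing is collected automatically, the handles have to be built and passed explicitly.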


I need something like this (example image not available): a legend showing one colored entry per class.

Solution:

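imshow() draws a single image, so there are no per-class labeled artists for ax.legend() to pick up. You need to create a custom legend using the same cmap and norm (vmin and vmax) as in ax[1].imshow(). For each unique "class", a handle is created from a dummy rectangle whose color is cmap(norm(class)), which is exactly the color imshow uses for that class.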
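One way to guarantee that the legend can never drift from the image is to take the cmap and norm straight from the AxesImage that imshow() returns. A minimal sketch, assuming matplotlib 3.x; legend_handles_from_image is a hypothetical helper name, and it uses matplotlib.patches.Patch instead of a zero-size Rectangle (both work as legend proxies).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np


def legend_handles_from_image(im, values):
    """Hypothetical helper: one legend patch per value, colored as `im` colors it."""
    # im.cmap and im.norm are the exact colormap and normalization imshow used.
    return [
        mpatches.Patch(color=im.cmap(im.norm(v)), label=str(v))
        for v in sorted(values)
    ]


classes = np.array([1, 2, 3, 1, 3, 3])
fig, ax = plt.subplots()
im = ax.imshow(classes.reshape(-1, 1), cmap="Paired", vmin=1, vmax=3,
               interpolation="none", aspect="auto")
handles = legend_handles_from_image(im, np.unique(classes))
ax.legend(handles=handles, title="class")
```

If you later change vmin, vmax or the cmap in the imshow() call, the legend follows automatically because nothing is duplicated.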

Here is an example:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Data
df = pd.DataFrame({"a": [0.1, 0.3, 0.4, 0.2, 0.5, 0.6],
                   "z": [5, 10, 15, 20, 25, 30],
                   "class": [1, 2, 3, 1, 3, 3]})

cluster = df["class"].values.reshape(-1, 1)

f, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 10))
ax[0].plot(df["a"], df["z"])
norm = plt.Normalize(vmin=df["class"].min(), vmax=df["class"].max())
cmap = plt.get_cmap('Paired')
ax[1].imshow(cluster, interpolation='none', cmap=cmap, norm=norm,
             aspect='auto', extent=[0, 1, df["z"].min(), df["z"].max()])

ax[0].set_ylim(df["z"].min(), df["z"].max())

# One zero-size dummy rectangle per class, colored exactly as imshow colors that class
handles = [plt.Rectangle((0, 0), 0, 0, color=cmap(norm(i)), label=i) for i in sorted(df["class"].unique())]
ax[1].legend(handles=handles, title='class')
plt.tight_layout()
plt.show()

[Figure: imshow with custom legend]
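It may help to see which "Paired" colors the three classes actually get. With vmin=1 and vmax=3, Normalize maps the classes to 0.0, 0.5 and 1.0, and "Paired" is a qualitative colormap with 12 listed entries, so the classes land on entries 0, 6 and 11 rather than on the first three. The sketch below reproduces that lookup by hand; the index rule in the comment is our reading of how a ListedColormap indexes floats in [0, 1], and the np.allclose check compares it against the colormap itself.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is needed here
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba

norm = plt.Normalize(vmin=1, vmax=3)  # same limits as the answer's norm
cmap = plt.get_cmap("Paired")         # qualitative ListedColormap, 12 entries

picked = {}
for cls in (1, 2, 3):
    x = float(norm(cls))                    # 1 -> 0.0, 2 -> 0.5, 3 -> 1.0
    idx = min(int(x * cmap.N), cmap.N - 1)  # assumed lookup rule for floats in [0, 1]
    same = np.allclose(cmap(norm(cls)), to_rgba(cmap.colors[idx]))
    picked[cls] = idx
    print(cls, x, idx, same)
```

If you would rather have the classes use neighboring "Paired" entries, you have to change the mapping itself (for example the norm), not just the legend, since the legend only mirrors what imshow draws.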
