How do we label a matrix plot with a predefined list of labels (e.g. a list of prime numbers)?

I have a CSV file with elliptic-curve-related data that looks as follows (excerpt):

3,19,[(1,4,1),(4,7,44)]
3,13,[(4,1,10)]
5,11,[(4,39,14)]
3,7,[(1,2,1),(1,3622,111),(4,17,10)]
3,5,[(4,1,2),(4,959,124)]
3,23,[(3,5,2)]
5,13,[(2,7,2),(2,8,1),(2,47,70),...]
7,11,[(1,2,1),(1,13,2),(1,53,4),...]
3,29,[(4,13,4)]
5,17,[(2,2,3),(4,6881,498)]
5,19,[(4,71,18)]
3,37,[(4,25,14)]
7,17,[(4,19,30)]
7,19,[(3,66,5)]

The first and second numbers in each line form a pair of odd primes, and the list that follows them determines the color of the corresponding matrix cell. I want to generate a matrix plot whose x-axis and y-axis are labeled with the first 200 odd primes: 3, 5, 7, 11, 13, 17, 19, 23, ..., 1193, 1201, 1213, 1217, 1223, 1229. The following code generates the colorized matrix plot correctly.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from sympy import sieve, prime
import itertools
import pandas as pd

matrix_size = 200
matrix = np.zeros((matrix_size, matrix_size))

with open('C:/Users/esultano/git/elliptic_curves/data/elliptic_curves.csv') as f:
    for line in f.readlines():
        line = line.strip()
        p, q, cases = eval(line)
        idx_p = sieve.search(p)[0]-2
        idx_q = sieve.search(q)[0]-2
        if idx_p < matrix_size and idx_q < matrix_size:
            cases_set = set()
            for case in cases:
                cases_set.add(case[0])
            cases_list = list(cases_set)
            val = sum(i*i for i in cases_list)
            matrix[idx_p,idx_q] = val
            matrix[idx_q,idx_p] = val

max_prime = prime(matrix_size+2)
axis_labels = list(sieve.primerange(3, max_prime))

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
matrix_plot = ax.matshow(matrix, interpolation='nearest')
ax.set_xticklabels(axis_labels)
ax.set_yticklabels(axis_labels)
plt.show()

The plot looks as follows:

[Figure: the resulting matrix plot, with incorrect axis labels]

Unfortunately, the axis labeling is not correct: each axis should span the primes from 3 to 1229. How can I fix the axis labeling? I do not need every one of these 200 primes to appear as an axis label (displaying a few labels would be fine).

>Solution :

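As the documentation of Axes.set_xticklabels() says, you should fix the tick positions with Axes.set_xticks() first. Otherwise the labels are attached to whatever tick positions the default locator picks and may end up in unexpected places.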
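A minimal sketch of that ordering, assuming a small 8x8 stand-in matrix and hand-picked tick positions (both are illustrative and not taken from the question's data):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display; drop for interactive use
import matplotlib.pyplot as plt
import numpy as np

labels = [3, 5, 7, 11, 13, 17, 19, 23]   # first 8 odd primes, one per row/column
matrix = np.arange(64).reshape(8, 8)     # placeholder data

fig, ax = plt.subplots()
ax.matshow(matrix)

# Choose which row/column indices get a tick, then label them with the matching primes.
positions = [0, 3, 7]
tick_text = [labels[i] for i in positions]

# Positions first, labels second: each label is now bound to a known index.
ax.set_xticks(positions)
ax.set_xticklabels(tick_text)
ax.set_yticks(positions)
ax.set_yticklabels(tick_text)
```

With an interactive backend, calling plt.show() afterwards displays the figure with ticks at columns 0, 3 and 7 only.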

I made the necessary modifications to your code so that it works now. You can see the differences with any text-diff tool, such as WinDiff.exe on Windows. A screenshot of the output follows the code.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from sympy import sieve, prime
import itertools
import pandas as pd

matrix_size = 200
matrix = np.zeros((matrix_size, matrix_size))

with open('elliptic_curves.csv') as f:
    for line in f.readlines():
        line = line.strip()
        p, q, cases = eval(line)
        idx_p = sieve.search(p)[0]-2
        idx_q = sieve.search(q)[0]-2
        if idx_p < matrix_size and idx_q < matrix_size:
            cases_set = set()
            for case in cases:
                cases_set.add(case[0])
            cases_list = list(cases_set)
            val = sum(i*i for i in cases_list)
            matrix[idx_p,idx_q] = val
            matrix[idx_q,idx_p] = val

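# prime(matrix_size+2) is the first prime past the last label; primerange excludes it
max_prime = prime(matrix_size+2)
# Pair every label with its matrix index: (0, 3), (1, 5), ..., (199, 1229)
axis_labels = list(enumerate(sieve.primerange(3, max_prime)))
# Keep every 22nd pair, then swap the last sampled pair for the final one
axis_labels = axis_labels[::len(axis_labels) // 9][:-1] + [axis_labels[-1]]
ticks = [e[0] for e in axis_labels]
ticklabels = [e[1] for e in axis_labels]

fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111)
matrix_plot = ax.matshow(matrix, interpolation='nearest')
# Fix the tick positions first, then attach the labels to them
ax.set_xticks(ticks)
ax.set_xticklabels(ticklabels)
ax.set_yticks(ticks)
ax.set_yticklabels(ticklabels)
plt.show()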

Output:

[Figure: matrix plot with axis ticks labeled by primes from 3 to 1229]
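The tick selection in the code above can also be pulled out into a small helper. The sketch below is only illustrative: the names pick_ticks, first_odd_primes and n_intervals are hypothetical, the primes are generated by plain trial division instead of sympy so that the block has no dependencies, and the max(1, ...) guard for very short label lists is an addition that the code above does not have.

```python
def first_odd_primes(count):
    """Return the first `count` odd primes using trial division (fine for small counts)."""
    primes = []
    candidate = 3
    while len(primes) < count:
        # Only odd candidates are tried, so dividing by the odd primes found so far suffices.
        if all(candidate % p for p in primes if p * p <= candidate):
            primes.append(candidate)
        candidate += 2
    return primes


def pick_ticks(labels, n_intervals=9):
    """Return (positions, tick_labels): roughly evenly spaced ticks that end on the last label."""
    if not labels:
        return [], []
    # Guard against a zero step when there are fewer labels than intervals.
    step = max(1, len(labels) // n_intervals)
    # Same idea as the slicing above: sample every `step`-th index,
    # then replace the last sampled index with the final index.
    positions = list(range(0, len(labels), step))[:-1] + [len(labels) - 1]
    return positions, [labels[i] for i in positions]


primes = first_odd_primes(200)
ticks, ticklabels = pick_ticks(primes)
```

The resulting ticks and ticklabels can be passed to ax.set_xticks() and ax.set_xticklabels() (and their y-axis counterparts) exactly as in the code above.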
