Efficient masked argsort in Numpy

I have a numpy array such as this one:

arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1]
])

I would like to argsort it row by row, but with an arr <= 0 mask applied, so that masked elements are left out of the result. The output should be:

array([[0, 1, 2],
       [0, 2],       # (Note that the indices are still relative to original un-masked array)
       []])

However, the output I get using np.ma.argsort() is:

array([[0, 1, 2],
       [0, 2, 1],
       [0, 1, 2]])

The approach needs to be very efficient because the real array has millions of columns. I am thinking this needs to be a synthesis of a few operations, but I don’t know which ones.

Solution:

The np.where approach:

import numpy as np

# Input array
arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1]
])

# Mask of valid elements (True where the value should be kept)
mask = arr > 0

# Preallocate result as an object array to hold variable-length indices
result = np.empty(arr.shape[0], dtype=object)

# Masked argsort, one row at a time
for i in range(arr.shape[0]):
    valid_indices = np.where(mask[i])[0]  # Column indices of the valid (unmasked) elements
    result[i] = valid_indices[np.argsort(arr[i, valid_indices])]  # Order those indices by their values

print(result)

Output:

[array([0, 1, 2]) array([0, 2]) array([], dtype=int64)]

The np.flatnonzero approach:

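A slightly leaner variant of the same approach. It still loops over the rows in Python (it is not fully vectorised), but uses np.flatnonzero in place of np.where(...)[0]: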

def optimized_masked_argsort(arr, mask):
    result = np.empty(arr.shape[0], dtype=object)
    for i in range(arr.shape[0]):
        row = arr[i]
        valid_indices = np.flatnonzero(mask[i])  # Same indices as np.where(mask[i])[0]; marginally faster in the timings below
        valid_values = row[valid_indices]
        sorted_order = np.argsort(valid_values)
        result[i] = valid_indices[sorted_order]
    return result

Comparison:

Timings for given example:
np.where Time: 0.000034 seconds
np.flatnonzero Time: 0.000017 seconds

Timings for larger array (1000 rows):
np.where Time: 0.001856 seconds
np.flatnonzero Time: 0.001754 seconds
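
The timings above are reported without the code that produced them. A minimal sketch of how such a comparison could be run with timeit is shown here. The function names argsort_where and argsort_flatnonzero, the array shape (1000 rows by 50 columns), and the value range are assumptions made for illustration, not the setup used for the numbers above, so the printed times will differ.

```python
import timeit

import numpy as np


def argsort_where(arr, mask):
    """Per-row masked argsort using np.where (hypothetical name)."""
    result = np.empty(arr.shape[0], dtype=object)
    for i in range(arr.shape[0]):
        valid = np.where(mask[i])[0]
        result[i] = valid[np.argsort(arr[i, valid])]
    return result


def argsort_flatnonzero(arr, mask):
    """Per-row masked argsort using np.flatnonzero (hypothetical name)."""
    result = np.empty(arr.shape[0], dtype=object)
    for i in range(arr.shape[0]):
        valid = np.flatnonzero(mask[i])
        result[i] = valid[np.argsort(arr[i, valid])]
    return result


# Assumed benchmark data: the shape and value range are illustrative only
rng = np.random.default_rng(0)
big = rng.integers(-10, 10, size=(1000, 50))
big_mask = big > 0

for name, fn in [("np.where", argsort_where),
                 ("np.flatnonzero", argsort_flatnonzero)]:
    # Take the best of several single runs to reduce timing noise
    seconds = min(timeit.repeat(lambda: fn(big, big_mask), number=1, repeat=5))
    print(f"{name} Time: {seconds:.6f} seconds")
```

Taking the minimum over several repeats is a common way to reduce noise from other processes; a single run of a function this fast can vary considerably.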

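On the 1000-row array the two variants are within about 5% of each other, which suggests the choice between them matters little for larger inputs. I tried a few other methods, but they were slower.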
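
For completeness, here is a sketch of one alternative that builds on the np.ma argsort from the question. With its default settings, the masked argsort places masked entries at the end of each row, so cutting each row to its number of valid entries should give the desired indices. This is not claimed to be faster than the loops above: it sorts every element, masked or not. It also assumes integer data that never contains the dtype's maximum value, since that value is what masked entries are filled with before sorting.

```python
import numpy as np

arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1],
])
mask = arr > 0  # True marks the entries to keep

# np.ma uses the opposite convention (True means "masked out"), hence ~mask.
# axis=1 is passed explicitly so that each row is sorted on its own.
order = np.ma.array(arr, mask=~mask).argsort(axis=1)

# The number of valid entries per row tells us where to cut each row of `order`
counts = mask.sum(axis=1)
result = [order[i, :n] for i, n in enumerate(counts)]

print(result)
```

The final slicing step is still a Python loop over rows, but it only takes views of an array that has already been sorted, so no per-row sorting call is made.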
