np.where for 2d array, manipulate whole rows

I want to rebuild the following logic with a vectorized NumPy function such as np.where (relying on broadcasting instead of a Python loop): for each row of a 2D array, check whether the first element satisfies a condition. If it does, return the first three elements of that row; otherwise, return the last three elements.

A short MWE in the form of a for-loop that I want to avoid:

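import numpy as np

array = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])

# One column fewer than the input; keep the input dtype (np.zeros defaults to float64)
new_array = np.zeros((array.shape[0], array.shape[1] - 1), dtype=array.dtype)
for i, row in enumerate(array):
    if row[0] == 1:
        new_array[i] = row[:3]   # condition true: first three elements
    else:
        new_array[i] = row[-3:]  # condition false: last three elements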

Solution:

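If you want to use np.where, build a boolean mask from the first column and reshape it into a column so that it broadcasts across every column of the row: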

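import numpy as np

array = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])

cond = array[:, 0] == 1  # shape (3,): True where the row starts with 1
# cond[:, None] has shape (3, 1), so it broadcasts across the columns
# and picks whole rows from either slice
np.where(cond[:, None], array[:, :3], array[:, -3:])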

Output:

array([[1, 2, 3],
       [1, 2, 4],
       [3, 4, 6]])
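The reshape matters: a mask of shape (3,) would be broadcast as a row, i.e. matched against columns instead of rows, and because both slices here also have three columns this fails silently rather than raising an error. The sketch below prints the two mask shapes and checks the np.where version against the loop from the question on random 4-column integer data. The function names select_rows_loop and select_rows_where and the variables example and data are hypothetical, introduced only for this illustration.

```python
import numpy as np

def select_rows_loop(array):
    # Hypothetical wrapper around the question's for-loop (dtype kept for exact comparison)
    new_array = np.zeros((array.shape[0], array.shape[1] - 1), dtype=array.dtype)
    for i, row in enumerate(array):
        if row[0] == 1:
            new_array[i] = row[:3]
        else:
            new_array[i] = row[-3:]
    return new_array

def select_rows_where(array):
    # Hypothetical wrapper around the np.where answer
    cond = array[:, 0] == 1  # shape (n,): one boolean per row
    mask = cond[:, None]     # shape (n, 1): broadcasts across the 3 columns
    return np.where(mask, array[:, :3], array[:, -3:])

example = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])
print((example[:, 0] == 1).shape)             # (3,)
print((example[:, 0] == 1)[:, None].shape)    # (3, 1)

# Compare both versions on random 4-column integer data
rng = np.random.default_rng(0)
data = rng.integers(0, 3, size=(1000, 4))
assert np.array_equal(select_rows_loop(data), select_rows_where(data))
```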

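EDIT

A slightly more concise version: indexing the first column with a list, array[:, [0]], keeps it 2D with shape (3, 1), so no explicit reshape is needed: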

np.where(array[:, [0]] == 1, array[:,:3], array[:,-3:])
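As an alternative (a different technique from the np.where answer), the same selection can be written as a per-row gather with np.take_along_axis: compute each row's starting column (0 if the first element is 1, otherwise the start of the last `width` columns) and add np.arange(width) to build an index array. This is a minimal sketch assuming the chosen elements are always `width` contiguous columns; select_rows_fancy and its width parameter are hypothetical names for illustration.

```python
import numpy as np

def select_rows_fancy(array, width=3):
    # Hypothetical helper: gather `width` contiguous columns per row, starting
    # at column 0 if the row's first element is 1, else ending at the last column
    n_cols = array.shape[1]
    start = np.where(array[:, 0] == 1, 0, n_cols - width)  # shape (n,)
    cols = start[:, None] + np.arange(width)                # shape (n, width)
    return np.take_along_axis(array, cols, axis=1)

example = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])
print(select_rows_fancy(example))
# [[1 2 3]
#  [1 2 4]
#  [3 4 6]]
```

Unlike np.where, which reads both candidate slices for every row, this gathers only the selected elements and takes the window width as a parameter; for a small fixed width like 3, the np.where one-liner is simpler to read.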