I would like to extract elements of a given multidimensional NumPy array using another array of indices. However, it doesn't behave the way I expected. Below is a simple example:
import numpy as np
a = np.random.random((3, 3, 3))
idx = np.asarray([[0, 0, 0], [0, 1, 2]])
b = a[idx]
print(b.shape)  # expected (2,), but got (2, 3, 3, 3)
Why is that the case? And how should I modify the code so that b contains only two elements: a[0, 0, 0] and a[0, 1, 2]?
Solution:
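You are looking for NumPy advanced indexing. When you index with a single integer array, NumPy applies it to the first axis only. Every entry of idx is treated as a first-axis index and selects a whole 3x3 slice a[i], so the result has shape idx.shape + a.shape[1:] = (2, 3) + (3, 3) = (2, 3, 3, 3).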
https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing
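A small sketch that reproduces the question's setup and shows where the extra dimensions come from. It only uses plain NumPy indexing; the check on b[1, 2] is just an illustration chosen because idx[1, 2] == 2.

```python
import numpy as np

a = np.random.random((3, 3, 3))
idx = np.asarray([[0, 0, 0], [0, 1, 2]])

# A single integer array indexes only the first axis: every entry of idx
# selects a full 3x3 slice a[i], so the result shape is idx.shape + (3, 3).
b = a[idx]
print(b.shape)  # (2, 3, 3, 3)

# idx[1, 2] == 2, so b[1, 2] is the whole slice a[2].
assert np.array_equal(b[1, 2], a[2])
```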
In your case, each row of idx is one (i, j, k) coordinate, so you need to pass its columns as separate index arrays, one per axis:
a[idx[:, 0], idx[:, 1], idx[:, 2]].shape == (2,)  # True
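The same selection can be written without spelling out each column. This is a minimal sketch assuming idx always has one column per axis of a. It transposes idx and converts it to a tuple, which NumPy interprets as one integer index array per dimension.

```python
import numpy as np

a = np.random.random((3, 3, 3))
idx = np.asarray([[0, 0, 0], [0, 1, 2]])

# idx.T has one row per axis; tuple(...) turns it into
# (axis0_indices, axis1_indices, axis2_indices).
b = a[tuple(idx.T)]
print(b.shape)  # (2,)

# Same result as the explicit column-by-column version.
assert np.array_equal(b, a[idx[:, 0], idx[:, 1], idx[:, 2]])
```

Because tuple(idx.T) adapts to the number of columns, this form also works unchanged for arrays with more or fewer dimensions.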