I will illustrate my question with an example. If I have the array:
a = np.array([[1,2,3,4],
[9,10,11,12],
[17,18,19,20],
[25,26,27,28]])
I would like to get
array([[[ 1,  2],
        [ 9, 10],
        [17, 18],
        [25, 26]],

       [[ 3,  4],
        [11, 12],
        [19, 20],
        [27, 28]]])
In other words, if my array is MxN, the result should be 2xMx(N/2): the left and right column halves stacked along a new first axis. How can I do this? I tried:
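import numpy as np
# a is the array defined above
a.reshape(a.shape[0], a.shape[1] // 2, 2)
but it does not work as expected: it cuts each row into consecutive pairs and returns an Mx(N/2)x2 array instead of the two column halves.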
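A minimal sketch of why the attempt fails, plus one alternative that is not part of the original answer: reshape each row into 2 halves of width N/2 (shape Mx2x(N/2)), then move the "half" axis to the front with transpose. It assumes N is even; the names `wrong` and `halves` are illustrative only.

```python
import numpy as np

a = np.array([[1, 2, 3, 4],
              [9, 10, 11, 12],
              [17, 18, 19, 20],
              [25, 26, 27, 28]])
M, N = a.shape

# The attempted reshape: each ROW is cut into consecutive pairs,
# so wrong[0] is [[1, 2], [3, 4]] and the shape is (M, N//2, 2).
wrong = a.reshape(M, N // 2, 2)

# Alternative: view each row as 2 halves of width N//2 -> shape (M, 2, N//2),
# then swap the first two axes -> shape (2, M, N//2).
halves = a.reshape(M, 2, N // 2).transpose(1, 0, 2)

print(wrong.shape)   # (4, 2, 2)
print(halves.shape)  # (2, 4, 2)
print(halves[0])     # left column half
```

`transpose` returns a view, so no data is copied at this step, whereas `np.stack` always allocates a new array.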
Solution:
Use np.split along the column axis (axis=1) combined with np.stack:
arr = np.stack(np.split(a, 2, axis=1))
print(arr)
[[[ 1  2]
  [ 9 10]
  [17 18]
  [25 26]]

 [[ 3  4]
  [11 12]
  [19 20]
  [27 28]]]
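The same split-and-stack pattern generalizes to k equal column blocks, giving shape (k, M, N/k). Below is a small sketch with a hypothetical helper `split_columns` (not part of NumPy); note that `np.split` with an integer section count raises `ValueError` when N is not divisible by k.

```python
import numpy as np

def split_columns(arr, k):
    """Hypothetical helper: cut an (M, N) array into k equal column blocks
    and stack them along a new leading axis -> shape (k, M, N // k)."""
    # np.split raises ValueError if arr.shape[1] is not divisible by k
    return np.stack(np.split(arr, k, axis=1))

a = np.arange(1, 13).reshape(2, 6)  # rows [1..6] and [7..12]
out = split_columns(a, 3)
print(out.shape)  # (3, 2, 2)
print(out[1])     # middle column block, i.e. a[:, 2:4]
```

If unequal blocks are acceptable, `np.array_split` tolerates a remainder, but the pieces then differ in width and cannot be passed to `np.stack`.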