I have the following code snippet:
import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]
How do I generalize this to any a of shape (n,3,3), b of shape (n,3), and c of shape (n,3)?
I think einsum is the way to go, but I can’t quite figure out the right syntax…
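For reference, the direct generalization of the two hand-written lines is a loop over the first axis. This is a minimal sketch; the function name `batched_matvec_loop` is hypothetical, and it is useful mainly as a baseline to check a vectorized version against.

```python
import numpy as np

def batched_matvec_loop(a, b):
    """Hypothetical baseline: compute c[i] = a[i] @ b[i] for every i in the batch."""
    n = a.shape[0]
    # Pick a dtype that can hold the result of mixing a's and b's dtypes
    c = np.zeros((n, a.shape[1]), dtype=np.result_type(a, b))
    for i in range(n):
        c[i] = a[i] @ b[i]
    return c

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)
c = batched_matvec_loop(a, b)
print(c)  # [[  5  14  23]
          #  [122 158 194]]
```

This works for any n, but the Python-level loop gets slow for large batches, which is why a vectorized form is preferable.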
Solution:
You can either use broadcasting or np.einsum; einsum is the more concise and readable of the two:
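Both approaches compute the same quantity. Written per element, the batched product c[i] = a[i] @ b[i] is:

$$
c_{ij} = \sum_{k=0}^{2} a_{ijk}\, b_{ik}, \qquad i = 0,\dots,n-1,\quad j = 0,1,2
$$

The index k is repeated and summed away while i and j survive into the output, which is exactly what the einsum subscripts 'ijk,ik->ij' express. The broadcasting version does the same thing explicitly: b[:,None,:] has shape (n,1,3), so a*b[:,None,:] has shape (n,3,3) with entries a[i,j,k]*b[i,k], and .sum(2) sums over k.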
import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]
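res_broad = (a*b[:,None,:]).sum(2)  # b[:,None,:] has shape (n,1,3); multiply elementwise, then sum over k
res_ein = np.einsum('ijk,ik->ij',a,b)  # c[i,j] = sum over k of a[i,j,k] * b[i,k]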
print(f"broadcast works: {np.allclose(c,res_broad)}")
print(f"einsum works: {np.allclose(c,res_ein)}")
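Since nothing in either expression depends on n = 2, the same code handles any batch size. Below is a sketch that wraps the einsum call in a small function (the name `batched_matvec` is hypothetical) and checks it against the loop on random data with n = 5. For comparison it also includes np.matmul with a trailing axis added to b, an alternative not mentioned in the answer: `a @ b[:, :, None]` treats each b[i] as a (3,1) column, and `[:, :, 0]` drops that axis again.

```python
import numpy as np

def batched_matvec(a, b):
    """Hypothetical helper: c[i] = a[i] @ b[i] for a of shape (n, m, k), b of shape (n, k)."""
    return np.einsum('ijk,ik->ij', a, b)

rng = np.random.default_rng(0)
n = 5
a = rng.standard_normal((n, 3, 3))
b = rng.standard_normal((n, 3))

c_loop = np.stack([a[i] @ b[i] for i in range(n)])   # reference result
c_ein = batched_matvec(a, b)                        # einsum version
c_broad = (a * b[:, None, :]).sum(axis=2)           # broadcasting version
c_matmul = (a @ b[:, :, None])[:, :, 0]             # matmul on (n,3,3) @ (n,3,1)

print(c_ein.shape)  # (5, 3)
print(np.allclose(c_loop, c_ein), np.allclose(c_loop, c_broad), np.allclose(c_loop, c_matmul))
```

All three vectorized forms avoid the Python loop; einsum has the advantage that the subscript string documents the contraction directly.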