I am studying some code and I came across a usage of PyTorch’s einsum function that I am not understanding. The docs are here.
The snippet looks like (slightly modified from the original):
import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
print(result.shape)
>> torch.Size([64, 64, 64, 25])
So the notation is such that n=64, c=64, u=25, v=25, t=64.
I’m not too sure what’s happening. My reading is that each of the t=64 vectors of length 25 in y is multiplied elementwise with each of the u=25 vectors of length 25 in x, and the products are then summed over v; in other words, a set of dot products between 25-dimensional vectors. Is that right?
Any insights appreciated.
>Solution :
Basically, you can think of it as taking dot products over certain dimensions, and reorganizing the rest.
For simplicity, let’s ignore the batch dimensions n and c (they appear unchanged on both sides of ncuv,nctv->nctu) and look at a reduced example:
import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
print(result.shape)
>> torch.Size([64, 25])
Note that v vanishes after the einsum, meaning v is the dimension being summed over, while t and u are not. You can interpret it this way: x is a collection of 25 25-dimensional vectors, and y is a collection of 64 25-dimensional vectors. The dot product of the t-th vector in y with the u-th vector in x is computed and placed in the t-th row and u-th column of result.
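In fact this reduced pattern is just a plain matrix product against the transpose of x. A quick sanity check (reusing the names from the snippet above):
import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
# 'uv,tv->tu' matches y @ x.T: the t-th row of y is dotted with the u-th row of x
print(torch.allclose(result, y @ x.T))
>> True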
You can also write it as an equation:
result[n,c,t,u] = sum_v x[n,c,u,v] * y[n,c,t,v],   for each n, c, t, u
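If you want to convince yourself of this, here is a minimal sketch that checks the equation against explicit loops (the dimension sizes are shrunk so the naive loops stay fast; the sizes themselves are arbitrary):
import torch
# small sizes for a quick check: n=2, c=3, t=6, u=4, v=5
x = torch.rand(2, 3, 4, 5)   # indexed as [n, c, u, v]
y = torch.rand(2, 3, 6, 5)   # indexed as [n, c, t, v]
result = torch.einsum('ncuv,nctv->nctu', x, y)
expected = torch.zeros(2, 3, 6, 4)
for n in range(2):
    for c in range(3):
        for t in range(6):
            for u in range(4):
                # dot product over v, as in the equation above
                expected[n, c, t, u] = (x[n, c, u, :] * y[n, c, t, :]).sum()
print(torch.allclose(result, expected))
>> True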
Note two things:
- the summation is over the indices that vanish in the pattern ncuv,nctv->nctu (here, v)
- the indices appearing on the right of the pattern are the indices of the resulting tensor
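As a side note, since n and c are carried through unchanged, the full pattern can also be written as a batched matrix multiplication; a sketch with the shapes from the question:
import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
# n and c act as batch dimensions; transposing the last two axes of x
# turns the sum over v into an ordinary matrix product
print(torch.allclose(result, y @ x.transpose(-2, -1)))
>> True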