Understanding an example of PyTorch's Einsum function

I am studying some code and I came across a usage of PyTorch’s einsum function that I am not understanding. The docs are here.

The snippet looks like (slightly modified from the original):

import torch
x = torch.rand(64, 64, 25, 25)
y = torch.rand(64, 64, 64, 25)
result = torch.einsum('ncuv,nctv->nctu', x, y)
print(result.shape)
>> torch.Size([64, 64, 64, 25])

So the notation is such that n=64, c=64, u=25, v=25, t=64.
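One way to see what this subscript pattern does is to reproduce it with ordinary tensor operations. The sketch below (my own sanity check, not from the original snippet) shows that this particular einsum is a batched matrix multiply with `x` transposed in its last two dimensions:

```python
import torch

x = torch.rand(64, 64, 25, 25)   # dims: n, c, u, v
y = torch.rand(64, 64, 64, 25)   # dims: n, c, t, v

result = torch.einsum('ncuv,nctv->nctu', x, y)

# Equivalent batched matmul: contract over v,
# i.e. (..., t, v) @ (..., v, u) -> (..., t, u)
manual = y @ x.transpose(-1, -2)

print(torch.allclose(result, manual, atol=1e-5))  # True
```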


I’m not too sure what’s happening. My guess is that each of the 64 vectors along t (each 25-dimensional) is multiplied elementwise with each of the u=25 vectors of size 25, and the results are summed, i.e. 25 dot products of 25-dimensional vectors for each t. Is that right?

Any insights appreciated.

>Solution :

Basically, you can think of it as taking dot products over certain dimensions, and reorganizing the rest.

For simplicity, let’s ignore the batching dimensions n and c (since they are consistent before and after ncuv,nctv->nctu), and discuss:

import torch
x = torch.rand(25, 25)
y = torch.rand(64, 25)
result = torch.einsum('uv,tv->tu', x, y)
print(result.shape)
>> torch.Size([64, 25])

Note that v vanishes after einsum, meaning v is the dimension being summed over, while t and u are not. You can interpret it this way: x is a collection of 25 25-dimensional vectors; y is a collection of 64 25-dimensional vectors. The dot product of the t-th vector in y and the u-th vector in x is computed and placed in the t-th row and u-th column of result.
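In this 2-D case, the einsum is nothing more than an ordinary matrix product with x transposed. A quick check (my own illustration, not part of the question):

```python
import torch

x = torch.rand(25, 25)   # dims: u, v
y = torch.rand(64, 25)   # dims: t, v

result = torch.einsum('uv,tv->tu', x, y)

# Same contraction as a plain matmul: (t, v) @ (v, u) -> (t, u)
print(torch.allclose(result, y @ x.T))  # True
```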

You can also rewrite into a math equation:

result[n,c,t,u] = \sum_{v} x[n,c,u,v] * y[n,c,t,v], for each n, c, t, u

Note two things:

  • the summation is over the indices that vanish in the subscript pattern ncuv,nctv->nctu (here, v)
  • the indices appearing to the right of -> are the indices of the resulting tensor
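The equation above can be verified directly with explicit loops. This is just a sketch of the formula, using small (arbitrary) sizes so it runs quickly:

```python
import torch

n, c, t, u, v = 2, 3, 4, 5, 6
x = torch.rand(n, c, u, v)
y = torch.rand(n, c, t, v)

result = torch.einsum('ncuv,nctv->nctu', x, y)

# result[n,c,t,u] = sum_v x[n,c,u,v] * y[n,c,t,v]
manual = torch.empty(n, c, t, u)
for i in range(n):
    for j in range(c):
        for k in range(t):
            for l in range(u):
                # dot product over v of the u-th row of x and the t-th row of y
                manual[i, j, k, l] = (x[i, j, l] * y[i, j, k]).sum()

print(torch.allclose(result, manual, atol=1e-5))  # True
```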