Reshaping the dimension of a tensor in PyTorch

I have a tensor of shape [b, nt*nh*nw, dim], and the values of nt, nh, and nw are known. How can I reshape it to [b, dim, nt, nh, nw]? For example, how is it possible to reshape [2, 3x2x4, 512] to [2, 512, 3, 2, 4]?

>Solution :

It all depends on how your data is laid out in memory.

However, assuming the middle axis was flattened in the order nt, nh, nw (with nw varying fastest), you can get the desired shape by permuting and then reshaping your tensor.

First swap dimensions to place dim as the 2nd axis using torch.transpose or torch.permute. Then split the flattened axis into nt, nh, and nw with Tensor.view or torch.reshape:

>>> x.transpose(1,2).view(b, dim, nt, nh, nw)
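
To check the one-liner above on the sizes from the question, here is a minimal sketch. It assumes the middle axis was flattened in row-major order (nt, nh, nw), so that flat index n corresponds to (t * nh + h) * nw + w. The dummy input built with torch.arange and the names out, out2, and wrong are illustrative only.

```python
import torch

# Sizes from the question's example: [2, 3*2*4, 512] -> [2, 512, 3, 2, 4]
b, nt, nh, nw, dim = 2, 3, 2, 4, 512

# Dummy input of shape [b, nt*nh*nw, dim]; the values are arbitrary but distinct.
x = torch.arange(b * nt * nh * nw * dim, dtype=torch.float32)
x = x.reshape(b, nt * nh * nw, dim)

# Move dim to axis 1, then split the flattened axis into (nt, nh, nw).
out = x.transpose(1, 2).view(b, dim, nt, nh, nw)
print(out.shape)  # torch.Size([2, 512, 3, 2, 4])

# Assumption: flat index n maps to (t, h, w) as n = (t * nh + h) * nw + w.
t, h, w = 2, 1, 3
n = (t * nh + h) * nw + w
assert torch.equal(out[:, :, t, h, w], x[:, n, :])

# permute + reshape gives the same result.
out2 = x.permute(0, 2, 1).reshape(b, dim, nt, nh, nw)
assert torch.equal(out, out2)

# Reshaping without the transpose yields the right shape but scrambled values.
wrong = x.reshape(b, dim, nt, nh, nw)
assert wrong.shape == out.shape and not torch.equal(wrong, out)
```

Note that out is a non-contiguous view of x. If a later operation needs contiguous memory, call out.contiguous() first.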