Reshaping the dimensions of a tensor in PyTorch

I have a tensor of shape [b, nt*nh*nw, dim], and the values of nt, nh, and nw are known. How can I reshape this tensor to [b, dim, nt, nh, nw]? For example, how can I reshape a tensor of shape [2, 3*2*4, 512] to [2, 512, 3, 2, 4]?

Solution:

It all depends on how your data is laid out in memory.

However, assuming the middle axis was flattened in (nt, nh, nw) order, with nw varying fastest, you can do this by transposing and then reshaping the tensor.
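To make the layout assumption concrete, here is a minimal sketch. It assumes the flattened tensor was produced by reshaping a [b, nt, nh, nw, dim] tensor, which is the row-major order the answer relies on. The sizes come from the question's example, and `src` is a hypothetical name for that original 5-D tensor. If your data was flattened in some other order, the index formula (and therefore the reshape) would be different.

```python
import torch

# Sizes from the question's example.
b, nt, nh, nw, dim = 2, 3, 2, 4, 512

# Assumed original layout: a hypothetical [b, nt, nh, nw, dim] tensor whose
# three middle axes were flattened in (nt, nh, nw) order, nw varying fastest.
src = torch.randn(b, nt, nh, nw, dim)
x = src.reshape(b, nt * nh * nw, dim)  # shape [2, 24, 512]

# Under this layout, position i of the middle axis holds the entry at
# (t, h, w) with i = t * nh * nw + h * nw + w.
for t in range(nt):
    for h in range(nh):
        for w in range(nw):
            i = t * nh * nw + h * nw + w
            assert torch.equal(x[:, i, :], src[:, t, h, w, :])
```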

First, swap dimensions so that dim becomes the second axis, using torch.transpose (or torch.permute). Then split the flattened axis into nt, nh, and nw with Tensor.view or torch.reshape:

>>> x.transpose(1,2).view(b, dim, nt, nh, nw)
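Here is a runnable version of the one-liner using the question's sizes, under the same layout assumption. For comparison, it also shows an equivalent ordering that splits first and permutes second. `view` succeeds after `transpose` because it only splits a single existing axis and does not merge non-contiguous ones. The result is a non-contiguous view, though.

```python
import torch

b, nt, nh, nw, dim = 2, 3, 2, 4, 512
x = torch.randn(b, nt * nh * nw, dim)  # shape [2, 24, 512]

# Move dim to axis 1, then split the flattened axis into (nt, nh, nw).
# view() works on the transposed tensor because it only splits one axis.
y = x.transpose(1, 2).view(b, dim, nt, nh, nw)

# Equivalent: split first (x is still contiguous), then move dim to axis 1.
z = x.view(b, nt, nh, nw, dim).permute(0, 4, 1, 2, 3)

print(y.shape)  # torch.Size([2, 512, 3, 2, 4])
assert torch.equal(y, z)
print(y.is_contiguous())  # False
```

Both results share memory with `x`. If a later operation needs contiguous memory, for example a `view` that merges axes again, call `.contiguous()` on the result or use `torch.reshape`, which copies only when it has to.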
