There is a tensor with the shape [b, nt*nh*nw, dim], and the values of nt, nh, and nw are known. How can I reshape this tensor to the form [b, dim, nt, nh, nw]? For example, how can I reshape [2, 3x2x4, 512] to [2, 512, 3, 2, 4]?
It all depends on your data layout in memory, i.e. on the order in which the nt, nh, and nw axes were flattened into the second dimension.
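Concretely, "layout" means how a flat index n along the second axis maps back to (t, h, w). The transpose-and-view approach assumes row-major flattening with nt outermost and nw innermost. That is an assumption about how the tensor was built, not something PyTorch can infer. Writing $n_t, n_h, n_w$ for nt, nh, nw:

$$
n = t \cdot (n_h n_w) + h \cdot n_w + w, \qquad 0 \le t < n_t,\; 0 \le h < n_h,\; 0 \le w < n_w,
$$

$$
y[b, d, t, h, w] = x[b, n, d].
$$

If the data were flattened in a different order, the split has to follow that order, and a different permutation is needed afterwards.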
If nt, nh, and nw are in the correct ordering (nt outermost, nw innermost) in your underlying data tensor, then you can do it by transposing and then reshaping your tensor:
>>> x.transpose(1,2).view(b, dim, nt, nh, nw)
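Here is a minimal runnable sketch of the line above. It assumes PyTorch and the question's example sizes. It checks the resulting shape and shows that splitting first, then permuting, gives the same tensor. It also spot-checks one element against the row-major index mapping. The random input and the indices `t, h, w, d` are illustrative only.

```python
import torch

# Sizes from the question's example: [2, 3x2x4, 512] -> [2, 512, 3, 2, 4]
b, nt, nh, nw, dim = 2, 3, 2, 4, 512

# Assumption: the middle axis is flattened row-major, nt outermost, nw innermost
x = torch.randn(b, nt * nh * nw, dim)

# Move `dim` in front of the flattened axis, then split that axis into (nt, nh, nw).
# view() is allowed after transpose() here because it only splits one dimension.
y = x.transpose(1, 2).view(b, dim, nt, nh, nw)
print(y.shape)  # torch.Size([2, 512, 3, 2, 4])

# Equivalent order of operations: split first, then move `dim` to position 1.
y2 = x.view(b, nt, nh, nw, dim).permute(0, 4, 1, 2, 3)
print(torch.equal(y, y2))  # True

# Spot-check one element against n = t*nh*nw + h*nw + w
t, h, w, d = 2, 1, 3, 100
n = t * nh * nw + h * nw + w
print(torch.equal(y[1, d, t, h, w], x[1, n, d]))  # True
```

`view` never copies, so it raises an error when the requested shape is incompatible with the tensor's strides. `reshape` accepts the same arguments and falls back to a copy in that case, which makes it the more forgiving choice when you are unsure about the layout.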