The tensor has shape [5, 2, 18, 4096]. I want to take each [2, 18, 4096] slice along the 0th dimension and stack it on top of the other slices from the same tensor, doing this for all tensors along the 0th dimension. The final tensor should have shape [2, 90, 4096].
>Solution :
It turns out there’s a very simple approach: for tensors with at least two dimensions, torch.hstack always stacks along the second dimension (i.e. along axis 1); only 1-D inputs are concatenated along axis 0. For instance, consider the following:
start = torch.arange(120).reshape([5,4,3,2])
result = torch.hstack(list(start))
The tensor start has shape (5, 4, 3, 2), but result has shape (4, 15, 2), which comes from stacking five (4, 3, 2) tensors along axis 1.
Applying list to a multidimensional tensor splits it along the first axis (dim 0). In this case, list(start) is a list containing five (4, 3, 2)-shaped tensors.
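Applied to the shapes from the question, this looks like the following sketch (using a random tensor in place of your data). The torch.cat call shown alongside is the more explicit equivalent, since hstack on n-D tensors is just concatenation along dim 1:

```python
import torch

# A tensor with the question's shape: [5, 2, 18, 4096]
x = torch.randn(5, 2, 18, 4096)

# list(x) splits x into 5 tensors of shape [2, 18, 4096];
# hstack concatenates them along dim 1, giving [2, 5*18, 4096].
stacked = torch.hstack(list(x))
print(stacked.shape)  # torch.Size([2, 90, 4096])

# Equivalent, more explicit form:
stacked2 = torch.cat(list(x), dim=1)
assert torch.equal(stacked, stacked2)
```

If you prefer to avoid the Python-level list, x.permute(1, 0, 2, 3).reshape(2, -1, 4096) produces the same shape, though the element order along dim 1 differs from the hstack result, so check which ordering you actually want.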