
Unflatten in PyTorch

I need to change the shape of a tensor from [2, 48, 196] to [2, 48, 14, 14]. I read that there is an "unflatten" in PyTorch, but I couldn't understand how to use it.
Is there an example?

Solution:


Here is an example for your case.

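```python
import torch

x = torch.randn(2, 48, 196)                  # input tensor of shape [2, 48, 196]
unflatten = torch.nn.Unflatten(2, (14, 14))  # split dim 2 (size 196) into (14, 14)
output = unflatten(x)                        # shape [2, 48, 14, 14]
```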

If you check output.shape, the shape is [2,48,14,14].
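The module is not the only way to get this shape. The sketch below compares nn.Unflatten with tensor-level equivalents; it assumes a recent PyTorch (2.x), where Tensor.unflatten and torch.unflatten accept ordinary (unnamed) tensors and allow one -1 in the new sizes. A plain reshape gives the same result here only because the full target shape is written out by hand.

```python
import torch

x = torch.randn(2, 48, 196)

# Module form, as in the answer
a = torch.nn.Unflatten(2, (14, 14))(x)

# Tensor method and function forms (recent PyTorch versions)
b = x.unflatten(2, (14, 14))
c = torch.unflatten(x, 2, (14, 14))
# -1 lets PyTorch infer one of the new sizes (196 / 14 = 14)
d = x.unflatten(2, (14, -1))

# reshape gives the same tensor when the full shape is spelled out
e = x.reshape(2, 48, 14, 14)

for t in (a, b, c, d, e):
    print(t.shape)  # torch.Size([2, 48, 14, 14])
    assert torch.equal(t, a)
```

The advantage of unflatten over reshape is that you only describe the dimension being split, so the code keeps working if the batch size or channel count changes.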

Unflatten expands a single dimension into several dimensions with the shape you specify. In your case, you want to expand dimension 2 (size 196) into the shape (14, 14). This works because 14 × 14 = 196: the new sizes must multiply to the original size of that dimension.
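A quick way to see what unflatten does is to pair it with its inverse, nn.Flatten, and to try a shape that does not multiply out. This is a small sketch; the mismatched (14, 15) shape is just an illustrative bad input.

```python
import torch

x = torch.randn(2, 48, 196)
y = torch.nn.Unflatten(2, (14, 14))(x)   # [2, 48, 14, 14]

# Flatten is the inverse: merge dims 2..3 back into one of size 14 * 14 = 196
z = torch.nn.Flatten(start_dim=2)(y)     # [2, 48, 196]
assert torch.equal(z, x)

# The new sizes must multiply to the original size of the dim (196)
try:
    torch.nn.Unflatten(2, (14, 15))(x)   # 14 * 15 = 210, not 196
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("error:", err)
```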

nn.Unflatten takes two parameters:

  1. The first parameter is dim, the dimension you want to unflatten. In your case, it is 2.
  2. The second parameter is unflattened_size, the new shape of that dimension. In your case, it is (14, 14).
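The two parameters can also be passed by keyword, dim can be negative (counted from the end, so -1 is dimension 2 for a 3-D input), and unflattened_size can be a tuple, list, or torch.Size. A minimal sketch checking that these spellings agree:

```python
import torch
from torch import nn

x = torch.randn(2, 48, 196)

# Keyword form spells out both parameters
u1 = nn.Unflatten(dim=2, unflattened_size=(14, 14))
# A negative dim counts from the end: -1 is the last dimension (dim 2 here)
u2 = nn.Unflatten(-1, (14, 14))
# unflattened_size may also be given as a torch.Size
u3 = nn.Unflatten(2, torch.Size([14, 14]))

outs = [u(x) for u in (u1, u2, u3)]
for o in outs:
    print(o.shape)  # torch.Size([2, 48, 14, 14])
```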

Therefore, your Unflatten module should look like this: unflatten = torch.nn.Unflatten(2, (14, 14))
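Because nn.Unflatten is a module, it can sit inside nn.Sequential. A common reason to restore a [2, 48, 14, 14] shape is to feed a 2-D layer that expects [batch, channels, height, width]. The Conv2d layer and its sizes below are hypothetical, chosen only for illustration; they are not part of the question.

```python
import torch
from torch import nn

# Hypothetical pipeline: restore the 14x14 grid, then apply an illustrative 2-D conv
model = nn.Sequential(
    nn.Unflatten(2, (14, 14)),                    # [2, 48, 196] -> [2, 48, 14, 14]
    nn.Conv2d(48, 16, kernel_size=3, padding=1),  # [2, 48, 14, 14] -> [2, 16, 14, 14]
)

x = torch.randn(2, 48, 196)
out = model(x)
print(out.shape)  # torch.Size([2, 16, 14, 14])
```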
