How does torch differentiate between batch and single values?

Here’s my neural network.

import torch
from torch import nn
from torch.utils.data import DataLoader

class NeuralNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super().__init__()
        # state_size is stored but not used by the layers below.
        self.state_size = state_size
        self.action_size = action_size
        # self.flatten = nn.Flatten()
        # The first layer takes exactly 1 input feature per sample.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(1, 30),
            nn.ReLU(),
            nn.Linear(30, 30),
            nn.ReLU(),
            nn.Linear(30, action_size)
        )

    def forward(self, x):
        x = self.linear_relu_stack(x)
        return x

As you can imagine, it accepts tensor inputs of shape torch.Size([1]). However, when I feed it batch data, for instance a tensor of shape torch.Size([10]), it throws the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x10 and 1x30)

For instance, this code works:

net = NeuralNetwork(10, 5)
x1 = torch.rand(1)
print(x1.shape)
out = net(x1)

But this fails:

x2 = torch.rand(10)
print(x2.shape)
out = net(x2)

Solution:

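You just need to reshape your input. nn.Linear(1, 30) treats the last dimension of its input as the feature dimension and every dimension before it as a batch dimension, so the last dimension must be 1. A tensor of shape torch.Size([10]) is therefore read as one sample with 10 features, not 10 samples with 1 feature each, which is why the error reports a 1x10 matrix. The inner dimensions must match for matrix multiplication, so give the batch a second dimension of 1, i.e. 10×1 times 1×30: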

x2 = torch.rand(10, 1)
print(x2.shape)  # torch.Size([10, 1])
out = net(x2)

This works for any batch size, e.g. x2 = torch.rand(100, 1), as long as the last dimension stays 1.
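If the data already exists as a 1-D tensor of scalars, one option is to add the trailing feature dimension with unsqueeze instead of generating the tensor in a new shape. The following is a minimal sketch using a single nn.Linear(1, 30) layer as a stand-in for the first layer of the network above; the names layer, batch, batch_2d and single are illustrative and not from the original question.

```python
import torch
from torch import nn

# Stand-in for the first layer in the question: 1 input feature, 30 outputs.
layer = nn.Linear(1, 30)

# A 1-D tensor of 10 scalars, intended as a batch of 10 samples.
batch = torch.rand(10)

# unsqueeze(-1) appends a trailing dimension of size 1: [10] -> [10, 1].
batch_2d = batch.unsqueeze(-1)
out = layer(batch_2d)
print(batch_2d.shape)  # torch.Size([10, 1])
print(out.shape)       # torch.Size([10, 30])

# A single sample of shape [1] works because its only dimension
# is interpreted as the feature dimension.
single = torch.rand(1)
print(layer(single).shape)  # torch.Size([30])

# Any extra leading dimensions are also treated as batch dimensions.
print(layer(torch.rand(4, 10, 1)).shape)  # torch.Size([4, 10, 30])
```

So torch does not really distinguish "batch" from "single" inputs: only the last dimension has to match the layer's in_features, and whatever comes before it is carried through to the output unchanged.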
