Calling a custom module in PyTorch with two parameters

I’ve been trying to create custom modules in PyTorch, starting with these two:

import torch

class VerySimple(torch.nn.Module):
  def __init__(self):
    super(VerySimple, self).__init__()

  def forward(self, x):
    return x * 3.0

class VerySimple2(torch.nn.Module):
  def __init__(self):
    super(VerySimple2, self).__init__()

  def forward(self, x, y):
    return x * y * 3.0

After that I created two very simple networks as such:

vs = VerySimple()
vs2 = VerySimple2()
print(vs(2.0))
print(vs2(2.0, 3.0))

Both calls work as I expect and print 6.0 and 18.0.

Now I try to create something a little more interesting like so:

class Simple2(torch.nn.Module):
  def __init__(self):
    super(Simple2, self).__init__()
    self.model1 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )
    self.model2 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )

  def forward(self, x, y):
    x1 = self.model1(x)
    y2 = self.model2(y)
    return torch.cat((x1,y2),1)

But now I get an "AttributeError" when I run the code below:

s2 = Simple2()
s2(2,3)

What am I doing wrong with the s2(2,3)?

Alternatively: What is the minimal working example with s2(2,3)?

As requested, here is the full log:

6.0
18.0
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-3-f3b0dc51220a> in <module>
     43 
     44 s2 = Simple2()
---> 45 s2(2,3)

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-3-f3b0dc51220a> in forward(self, x, y)
     38 
     39   def forward(self, x, y):
---> 40     x1 = self.model1(x)
     41     y2 = self.model2(y)
     42     return torch.cat((x1,y1),1)

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
    115     def forward(self, input):
    116         for module in self:
--> 117             input = module(input)
    118         return input
    119 

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     91 
     92     def forward(self, input: Tensor) -> Tensor:
---> 93         return F.linear(input, self.weight, self.bias)
     94 
     95     def extra_repr(self) -> str:

/opt/app-root/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1686         if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
   1687             return handle_torch_function(linear, tens_ops, input, weight, bias=bias)
-> 1688     if input.dim() == 2 and bias is not None:
   1689         # fused op is marginally faster
   1690         ret = torch.addmm(bias, input, weight.t())

AttributeError: 'int' object has no attribute 'dim'

I then tried the example from Tamir below, passing tensors instead of plain numbers:

x = torch.tensor([2.0])
y = torch.tensor([3.0])
s2(x,y)

But I end up with this error instead:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-8-c61a0803c9b9> in <module>
     43 
     44 s2 = Simple2()
---> 45 s2(torch.tensor([2.0]), torch.tensor([3.0]))
     46 

/opt/app-root/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-8-c61a0803c9b9> in forward(self, x, y)
     40     x1 = self.model1(x)
     41     y2 = self.model2(y)
---> 42     return torch.cat((x1,y2),1)
     43 
     44 s2 = Simple2()

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
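
This IndexError is consistent with the shapes involved: torch.tensor([2.0]) is 1-D with shape (1,), so each sub-model returns a 1-D tensor and there is no dimension 1 to concatenate along. A minimal sketch, assuming each input is meant to be a batch of one sample with one feature (shape (1, 1)); the model is the one from the question, only the input shapes differ:

```python
import torch

class Simple2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model1 = torch.nn.Sequential(
            torch.nn.Linear(1, 3),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 1),
        )
        self.model2 = torch.nn.Sequential(
            torch.nn.Linear(1, 3),
            torch.nn.ReLU(),
            torch.nn.Linear(3, 1),
        )

    def forward(self, x, y):
        x1 = self.model1(x)  # shape (batch, 1)
        y2 = self.model2(y)  # shape (batch, 1)
        # dim 1 exists because the inputs are 2-D
        return torch.cat((x1, y2), 1)

s2 = Simple2()
x = torch.tensor([[2.0]])  # assumed layout: (batch=1, features=1)
y = torch.tensor([[3.0]])
out = s2(x, y)
print(out.shape)  # torch.Size([1, 2])
```

The values in out depend on the randomly initialised weights, so only the shape is predictable here.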

Final note:
To get it to work with Tamir’s solution, I had to modify the Simple2 example as follows:

class Simple2(torch.nn.Module):
  def __init__(self):
    super(Simple2, self).__init__()
    self.model1 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )
    self.model2 = torch.nn.Sequential(
        torch.nn.Linear(1, 3),
        torch.nn.ReLU(),
        torch.nn.Linear(3, 1)
    )

  def forward(self, x, y):
    x1 = self.model1(x)
    y2 = self.model2(y)
    # replaced "return torch.cat((x1,y2),1)" with the line below
    return x1 + y2
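
Note that x1 + y2 returns a single summed value rather than the pair of values the original torch.cat call was aiming for. If the two outputs should stay separate, one possible alternative is to concatenate along the last dimension, which exists for both 1-D and 2-D tensors. A small sketch; cat_last is a hypothetical helper name, and the tensors stand in for the outputs of model1 and model2:

```python
import torch

def cat_last(x1, y2):
    # hypothetical helper: join two tensors along their last dimension
    return torch.cat((x1, y2), dim=-1)

a = torch.tensor([0.5])  # stand-in for x1, shape (1,)
b = torch.tensor([1.5])  # stand-in for y2, shape (1,)

flat = cat_last(a, b)                               # 1-D inputs
batched = cat_last(a.unsqueeze(0), b.unsqueeze(0))  # 2-D inputs, shape (1, 1) each

print(flat.shape)     # torch.Size([2])
print(batched.shape)  # torch.Size([1, 2])
```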

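Solution:

This is most likely a type issue: PyTorch's Linear and ReLU layers expect tensors as inputs, and you are passing plain Python integers. The tensors should also be floating point so that their dtype matches the layer weights.
Do something like

x = torch.tensor([2.0])
y = torch.tensor([3.0])
s2(x,y)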
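
If the goal is to keep calling the model with plain numbers, as in s2(2,3), the conversion has to happen somewhere before the Linear layers. A sketch of one way to do that; to_batch is a hypothetical helper, and it assumes each number is a single sample with a single feature:

```python
import torch

def to_batch(value):
    # hypothetical helper: turn a Python number into a float tensor of
    # shape (1, 1), i.e. one sample with one feature (an assumption)
    return torch.as_tensor(value, dtype=torch.float32).reshape(1, 1)

layer = torch.nn.Linear(1, 3)  # same input size as the first layer in Simple2
out = layer(to_batch(2))       # an int works once it has been converted

print(to_batch(2).dtype)  # torch.float32
print(out.shape)          # torch.Size([1, 3])
```

The same helper could be applied to both arguments before calling the model, or called at the top of forward.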