Use class object as method to emulate a callable object in Python

How can you define a class MyClass in Python 3 so that you can instantiate it with
obj = MyClass(param1, param2) and then call the instance itself to compute an operation, as in res = obj(in1, in2, in3)?

For instance, with PyTorch you can declare a model as mod = MyResNet50() and then compute its prediction as pred = mod(input).

Below is the code I tried. I declare a method and call it as obj.method().

import numpy as np


class MLP:

    def __init__(self, hidden_units: int, input_size: int):
        self.hidden_units = hidden_units
        self.input_size = input_size
        self.layer1 = np.random.normal(0, 0.01, size=(hidden_units, input_size))
        self.layer2 = np.random.normal(0, 0.01, size=(1, hidden_units))

    def sigmoid(self, z):
        return 1/(1 + np.exp(-z))

    def predict(self, x):
        # x: 1-D array of length input_size (named x to avoid shadowing the built-in input)
        pred = self.layer1.dot(x)
        pred = self.layer2.dot(pred)
        return self.sigmoid(pred)

my_MLP = MLP(5, 10)
pred = my_MLP.predict(np.random.normal(0, 0.01, 10))

Solution:

Implement the __call__ method. Python invokes it whenever an instance of the class is called with ():

class MyClass:
    def __init__(self, p1, p2):
        # initialize the instance from p1, p2 here
        pass
        
    def __call__(self, in1, in2, in3):
        return f'You rang? {in1} {in2} {in3}'

Example:

>>> obj = MyClass(1, 2)
>>> res = obj(1, 2, 3)
>>> res
'You rang? 1 2 3'
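
Applied to the MLP class from the question, one option is to keep predict and let __call__ delegate to it, so both spellings give the same result. This is a minimal sketch rather than a definitive design: the parameter name x and the choice to make sigmoid a static method are assumptions for illustration, not part of the original code.

```python
import numpy as np


class MLP:
    def __init__(self, hidden_units: int, input_size: int):
        self.layer1 = np.random.normal(0, 0.01, size=(hidden_units, input_size))
        self.layer2 = np.random.normal(0, 0.01, size=(1, hidden_units))

    @staticmethod
    def sigmoid(z):
        return 1 / (1 + np.exp(-z))

    def predict(self, x):
        # Same computation as in the question: two linear layers, then a sigmoid.
        hidden = self.layer1.dot(x)
        out = self.layer2.dot(hidden)
        return self.sigmoid(out)

    def __call__(self, x):
        # Delegate to predict so that my_mlp(x) and my_mlp.predict(x) agree.
        return self.predict(x)


my_mlp = MLP(5, 10)
x = np.random.normal(0, 0.01, 10)
pred = my_mlp(x)  # same result as my_mlp.predict(x)
```

Keeping predict as a named method means existing callers of obj.predict() continue to work, while new code can use the shorter obj(...) form.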

If neither the class nor any class it inherits from defines __call__, calling an instance raises a TypeError:

class MyClass:
    def __init__(self, p1, p2):
        # initialize the instance from p1, p2 here
        pass
    # no __call__ defined here, and none inherited from object

>>> MyClass(1,2)()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'MyClass' object is not callable
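
If you would rather check up front than catch the TypeError, the built-in callable() reports whether an object can be called. The sketch below uses hypothetical names (WithCall, WithoutCall, try_call) chosen only for illustration.

```python
class WithCall:
    def __call__(self, x):
        return x * 2


class WithoutCall:
    pass


def try_call(obj, arg):
    """Return obj(arg) if obj is callable, otherwise None."""
    if callable(obj):
        return obj(arg)
    return None


doubled = try_call(WithCall(), 21)     # instance defines __call__, so it is called
skipped = try_call(WithoutCall(), 21)  # instance is not callable, so None is returned
```

Note that the classes themselves are always callable (calling a class creates an instance); it is the instances that need __call__.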
