Freezing layers in a neural network in PyTorch

I have a cascaded neural network in which the output of the first network becomes the input of the second network. The first network is pretrained, so I just initialise it with the pretrained weights. However, I want to freeze the first network so that training only updates the weights of the second network. How can I do that? My network looks like:


import torch
import torch.nn as nn
from functools import reduce  # used by LambdaReduce below

### First network

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))

def get_first_model(load_weights = True):
    pretrained_model_reloaded_th = nn.Sequential( # Sequential,
        nn.Conv2d(4,300,(19, 1)),
        nn.BatchNorm2d(300),
        nn.ReLU(),
        nn.MaxPool2d((3, 1),(3, 1)),
        nn.Conv2d(300,200,(11, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1),(4, 1)),
        nn.Conv2d(200,200,(7, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1),(4, 1)),
        Lambda(lambda x: x.view(x.size(0),-1)), # Reshape,
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2000,1000)), # Linear,
        nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,1000)), # Linear,
        nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,164)), # Linear,
        nn.Sigmoid(),
    )
    if load_weights:
        sd = torch.load('pretrained_model.pth')
        pretrained_model_reloaded_th.load_state_dict(sd)
    return  pretrained_model_reloaded_th

### second network

def next_model_architecture():
    next_model = nn.Sequential(
    nn.Linear(164, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid())
    
    return next_model

### joining two networks
def cascading_model(first_model,next_model):
    network = nn.Sequential(first_model, next_model)
    return network


first_model = get_first_model(load_weights = True)
next_model = next_model_architecture()
network = cascading_model(first_model,next_model)

If I do:


first_model = first_model.eval()

Will this freeze my first network so that only the weights of the second network are updated during training?

Solution:

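No. Calling first_model.eval() only switches layers such as BatchNorm and Dropout to inference behaviour; it does not stop the weights from being updated. A parameter is frozen by setting its .requires_grad to False. Do this for every parameter of the module you want to freeze: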

for p in first_model.parameters():
    p.requires_grad = False
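
To show how this fits into a training step, here is a minimal sketch rather than a definitive recipe. The two small models are hypothetical stand-ins for the networks in the question (no 'pretrained_model.pth' is loaded), and the SGD optimizer, learning rate, and random data are assumptions chosen only for illustration. It freezes the first network, passes only the trainable parameters to the optimizer, and keeps the frozen part in eval mode so its BatchNorm running statistics are not updated either.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the question's two networks (much smaller,
# and without loading 'pretrained_model.pth').
first_model = nn.Sequential(nn.Linear(8, 164), nn.BatchNorm1d(164), nn.Sigmoid())
next_model = nn.Sequential(
    nn.Linear(164, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()
)
network = nn.Sequential(first_model, next_model)

# 1. Freeze: no gradients are computed for the first network's parameters.
for p in first_model.parameters():
    p.requires_grad = False

# 2. Give the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    [p for p in network.parameters() if p.requires_grad], lr=0.1
)

# Snapshot the first network's state (weights and BatchNorm buffers)
# and one weight matrix of the second network for comparison.
before = {k: v.clone() for k, v in first_model.state_dict().items()}
next_before = next_model[0].weight.detach().clone()

# 3. network.train() also switches first_model back to training mode,
#    so re-apply eval() on the frozen part afterwards.
network.train()
first_model.eval()

# One training step on random data (assumed shapes, illustration only).
x = torch.randn(16, 8)
y = torch.rand(16, 1)
loss = nn.functional.binary_cross_entropy(network(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Compare the state before and after the step.
first_unchanged = all(
    torch.equal(before[k], v) for k, v in first_model.state_dict().items()
)
second_changed = not torch.equal(next_before, next_model[0].weight)
print(first_unchanged, second_changed)
```

Setting requires_grad to False protects the weights, while eval() protects the BatchNorm running mean and variance, which are buffers rather than parameters. Both are needed if the first network should behave exactly as it did after pretraining.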