Accumulating gradients for a larger batch size with PyTorch

In order to mimic a larger batch size, I want to be able to accumulate gradients every N batches for a model in PyTorch, like:

def train(model, optimizer, dataloader, num_epochs):
    model.train()
    model.cuda()
    for epoch_num in range(1, num_epochs + 1):
        for batch_num, data in enumerate(dataloader):
            ims = data.to('cuda:0')
            loss = model(ims)
            loss.backward()
            if batch_num % N == 0 and batch_num != 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

For this approach, do I need to add the flag retain_graph=True, i.e.

loss.backward(retain_graph=True)

In this manner, are the gradients from each backward call simply summed for each parameter?

Solution:

You need to set retain_graph=True if you want to make multiple backward passes over the same computational graph, reusing the intermediate results from a single forward pass. This would be the case, for instance, if you called loss.backward() multiple times after computing loss once, or if you had multiple losses from different parts of the graph to backpropagate from.
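To make that distinction concrete, here is a minimal sketch (toy tensors, not taken from the question) of a situation that genuinely needs retain_graph=True, because the same graph is backpropagated through twice:

import torch

x = torch.randn(4, requires_grad=True)
y = (x * 2).sum()          # one forward pass builds one graph

# First backward pass: keep the graph alive because it will be reused.
y.backward(retain_graph=True)

# Second backward pass over the same graph; without retain_graph=True above,
# PyTorch would raise an error about backwarding through the graph a second time.
y.backward()

print(x.grad)  # gradients from the two calls are summed: 2 + 2 = 4 per element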

In your case, for each forward pass, you backpropagate exactly once. So you don’t need to store the intermediate results from the computational graph once the gradients are computed.

In short:

  • Intermediate outputs in the graph are cleared after a backward pass, unless explicitly preserved using retain_graph=True.
  • Gradients accumulate (sum) in each parameter's .grad by default, unless explicitly cleared using optimizer.zero_grad().
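Putting the two points together, here is a sketch of the accumulation loop without retain_graph=True. The start=1 counter (so the step happens exactly every N batches) and the division of the loss by N (so the accumulated gradient approximates the average over the larger virtual batch) are small adjustments and assumptions, not part of the original code:

def train(model, optimizer, dataloader, num_epochs, N=4):
    model.train()
    model.cuda()
    for epoch_num in range(1, num_epochs + 1):
        for batch_num, data in enumerate(dataloader, start=1):
            ims = data.to('cuda:0')
            loss = model(ims)            # assumes the model returns the loss, as in the question
            (loss / N).backward()        # one backward per forward pass: no retain_graph needed
            if batch_num % N == 0:
                optimizer.step()                       # apply the gradients summed over N batches
                optimizer.zero_grad(set_to_none=True)  # clear before the next group of batches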