How can I weigh each batch differently in pytorch CrossEntropyLoss

My understanding is that torch.nn.CrossEntropyLoss computes a loss value for each item in the batch and then reduces them (by default it takes the mean) to give the final loss. Suppose in the following toy example we want to multiply the loss of each of the 3 items in the batch by the corresponding element of an array multiplier before reducing. How do I achieve this?

import torch
import numpy as np

# Per-sample weights, one per batch item
multiplier = torch.from_numpy(np.array([1.0, -2.0, 5.0]))
# Logits for 3 samples over 2 classes
source = torch.from_numpy(
    np.array([[0.5, -0.6],
              [-3.0, -2.0],
              [-4.0, 2.3]]))
# Integer class labels
target = torch.from_numpy(np.array([0, 1, 0]))
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn(source, target)

Note that the weight argument of the function weighs the target classes differently; that is not what I want.


>Solution :

You can do it manually by not reducing the loss (i.e., not aggregating it by taking an average), and then computing your weighted sum by hand:

loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
loss = torch.sum(multiplier * loss_fn(source, target)) / torch.sum(multiplier)

See https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
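Putting it together, here is a runnable sketch using the toy values from the question. It also cross-checks the per-sample losses against a manual log-softmax computation, to confirm that reduction='none' really returns one loss per batch item:

```python
import torch

# Toy values from the question
multiplier = torch.tensor([1.0, -2.0, 5.0], dtype=torch.float64)
source = torch.tensor([[0.5, -0.6],
                       [-3.0, -2.0],
                       [-4.0, 2.3]], dtype=torch.float64)
target = torch.tensor([0, 1, 0])

# reduction='none' keeps one loss value per batch item
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
per_sample = loss_fn(source, target)  # shape (3,)

# Weighted combination by hand (here a weighted mean, as in the answer)
weighted = torch.sum(multiplier * per_sample) / torch.sum(multiplier)

# Cross-check: cross-entropy is the negative log-softmax at the target index
log_probs = source - torch.logsumexp(source, dim=1, keepdim=True)
manual = -log_probs[torch.arange(3), target]
assert torch.allclose(per_sample, manual)
```

If you want a plain weighted sum rather than a weighted mean, simply drop the division by torch.sum(multiplier).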
