My understanding is that torch.nn.CrossEntropyLoss computes a loss value for each item of the batch and then reduces them (by default, averaging) to give the final loss. Suppose in the following toy example we want to multiply the loss of each of the 3 items in the batch by the corresponding element of an array multiplier before reducing. How do I achieve this?
import torch
import numpy as np
multiplier = torch.from_numpy(np.array([1.0, -2.0, 5.0]))
source = torch.from_numpy(
    np.array([[0.5, -0.6],
              [-3.0, -2.0],
              [-4.0, 2.3]]))
target = torch.from_numpy(np.array([0, 1, 0]))
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn(source, target)
Note that the weight argument of the function weighs the various target classes differently, which is not what I want.
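For contrast, here is a short sketch of what the weight argument actually does: it scales each sample's loss by the weight of its target class, not by a per-sample factor (the class weights [1.0, 3.0] below are illustrative values, not from the question):

```python
import torch

logits = torch.tensor([[0.5, -0.6],
                       [-3.0, -2.0],
                       [-4.0, 2.3]])
targets = torch.tensor([0, 1, 0])

# class_weights[c] multiplies the loss of every sample whose target is class c
class_weights = torch.tensor([1.0, 3.0])  # illustrative values
weighted_fn = torch.nn.CrossEntropyLoss(weight=class_weights, reduction='none')
plain_fn = torch.nn.CrossEntropyLoss(reduction='none')

per_item_weighted = weighted_fn(logits, targets)
per_item = plain_fn(logits, targets)

# Each sample's loss is scaled by the weight of its target class:
assert torch.allclose(per_item_weighted, per_item * class_weights[targets])
```

So samples 0 and 2 (target class 0) are both scaled by 1.0 and sample 1 (target class 1) by 3.0; there is no way to give samples 0 and 2 different factors through this argument.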
> Solution:
You can do it manually by not reducing the loss (i.e. not aggregating it by taking an average), and then computing your weighted sum by hand:
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
loss = torch.sum(multiplier * loss_fn(source, target)) / torch.sum(multiplier)
see https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
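Putting the pieces together, a runnable sketch using the question's tensors (note the normalization at the end assumes the weights do not sum to zero; here 1.0 - 2.0 + 5.0 = 4.0):

```python
import torch
import numpy as np

# Per-sample weights, logits, and targets from the question
multiplier = torch.from_numpy(np.array([1.0, -2.0, 5.0]))
source = torch.from_numpy(
    np.array([[0.5, -0.6],
              [-3.0, -2.0],
              [-4.0, 2.3]]))
target = torch.from_numpy(np.array([0, 1, 0]))

# reduction='none' returns one loss value per batch item instead of a scalar
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
per_item = loss_fn(source, target)  # shape (3,)

# Weighted sum by hand, normalized by the total weight
loss = torch.sum(multiplier * per_item) / torch.sum(multiplier)

# Sanity check: the plain mean of the per-item losses matches the
# default reduction='mean' behaviour
default_loss = torch.nn.CrossEntropyLoss()(source, target)
assert torch.allclose(per_item.mean(), default_loss)
```

The resulting loss is a scalar tensor that still carries gradients, so it can be passed straight to loss.backward() in a training loop.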