import torch
from torch.nn.modules.loss import _WeightedLoss
class LabelSmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy loss with label smoothing.

    Integer class targets are turned into a smoothed distribution: the true
    class receives ``1 - smoothing`` and every other class receives
    ``smoothing / (n_classes - 1)``. The loss is the cross-entropy between this
    distribution and ``log_softmax(inputs)`` over the last dimension.

    Parameters
    ----------
    weight: torch.Tensor, optional
        Per-class weights of shape ``(C,)``. Each class's log-probability is
        multiplied by its weight before summing over classes. The ``"mean"``
        reduction is a plain average over the batch and is *not* normalised by
        the weights (unlike ``torch.nn.CrossEntropyLoss``).
    reduction: str
        Reduction applied to the per-sample losses.
        One of ``"none"``, ``"sum"`` or ``"mean"``.
    smoothing: float
        The label smoothing factor, in the half-open range ``[0, 1)``.

    Adapted from https://github.com/NingAnMe/Label-Smoothing-for-CrossEntropyLoss-PyTorch
    """

    def __init__(
        self,
        weight: torch.Tensor = None,
        reduction: str = "mean",
        smoothing: float = 0.0,
    ):
        # The base class stores ``weight`` (as a buffer) and ``reduction``,
        # so they are not re-assigned here.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: float = 0.0):
        """Return a label-smoothed one-hot encoding of ``targets``.

        Parameters
        ----------
        targets: torch.Tensor
            1-D tensor of ``N`` integer class indices in ``[0, n_classes)``.
        n_classes: int
            Number of classes ``C``.
        smoothing: float
            Probability mass, in ``[0, 1)``, taken from the true class and
            spread evenly over the other ``n_classes - 1`` classes.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(N, C)`` on ``targets.device`` with the default
            floating dtype, built without autograd tracking.

        Raises
        ------
        ValueError
            If ``smoothing`` is outside ``[0, 1)``.
        """
        # An explicit raise (unlike ``assert``) still validates under ``python -O``.
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing factor {smoothing} should be in the range [0, 1)")
        # With a single class there is no "other" class to receive mass; avoid
        # dividing by zero (that column is overwritten by scatter_ below anyway).
        off_value = smoothing / (n_classes - 1) if n_classes > 1 else 0.0
        with torch.no_grad():
            smoothed = torch.full((targets.size(0), n_classes), off_value, device=targets.device)
            # scatter_ needs int64 indices; write 1 - smoothing at each true class.
            smoothed.scatter_(1, targets.to(torch.int64).unsqueeze(1), 1.0 - smoothing)
        return smoothed

    def forward(self, inputs, targets):
        """Compute the label-smoothed cross-entropy.

        Parameters
        ----------
        inputs: torch.Tensor
            Unnormalised scores (logits) of shape ``(N, C)``. Log-probabilities
            are also accepted: ``log_softmax`` leaves an already log-normalised
            row unchanged, so both give the same loss.
        targets: torch.Tensor
            Integer class indices of shape ``(N,)``.

        Returns
        -------
        torch.Tensor
            Per-sample losses of shape ``(N,)`` for ``reduction="none"``,
            otherwise a scalar.

        Raises
        ------
        ValueError
            If ``self.reduction`` is unsupported or ``self.smoothing`` is
            outside ``[0, 1)``.
        """
        # Validate up front so no work is done for an unsupported reduction.
        if self.reduction not in ("none", "sum", "mean"):
            raise ValueError(
                f"{self.reduction} is not supported, please choose one of the following values"
                " [`sum`, `none`, `mean`]"
            )
        smoothed = LabelSmoothCrossEntropyLoss._smooth_one_hot(
            targets, inputs.size(-1), self.smoothing
        )
        # Cross-entropy is defined on log-probabilities. Normalising here makes
        # logits correct and is a no-op (log_softmax is idempotent) for inputs
        # that are already log-probabilities.
        log_probs = torch.nn.functional.log_softmax(inputs, dim=-1)
        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)
        loss = -(smoothed * log_probs).sum(-1)
        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            return loss.mean()
        return loss