import torch
from torch.nn.modules.loss import _WeightedLoss

class LabelSmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy loss with label smoothing.

    Integer class targets are converted into smoothed one-hot rows:
    ``1 - smoothing`` on the true class and ``smoothing / (n_classes - 1)``
    on every other class. The per-sample loss is ``-sum(target * inputs)``
    over the class dimension.

    ``inputs`` are used as-is, i.e. they are treated as log-probabilities;
    no ``log_softmax`` is applied inside this loss.
    NOTE(review): confirm callers pass log-probabilities (e.g. the output of
    ``torch.log_softmax``) rather than raw logits.

    Parameters
    ----------
    weight: torch.Tensor, optional
        Per-class weights, broadcast against each row of ``inputs``
        (typically shape ``(n_classes,)``). Every class term is multiplied by
        its weight, including the smoothed non-target terms. The ``mean``
        reduction is a plain mean over samples and is not normalised by the
        weights.
    reduction: str
        Specifies the reduction to apply to the output. Specify one of
        `none`, `sum`, or `mean`.
    smoothing: float
        The label smoothing factor, in the half-open interval ``[0, 1)``.

    NOTE(review): the original docstring ended with "Adapted from" but the
    source link was lost; restore the attribution if it is known.
    """

    def __init__(self, weight: torch.Tensor = None, reduction: str = "mean", smoothing: float = 0.0):
        # _WeightedLoss registers ``weight`` as a buffer and _Loss stores
        # ``reduction``, so neither needs to be re-assigned here.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    @staticmethod
    def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: float = 0.0) -> torch.Tensor:
        """Build a smoothed one-hot matrix from integer class labels.

        Parameters
        ----------
        targets: torch.Tensor
            1-D tensor of class indices, shape ``(N,)``.
        n_classes: int
            Number of classes ``C`` (width of the result).
        smoothing: float
            Probability mass moved off the true class, in ``[0, 1)``.

        Returns
        -------
        torch.Tensor
            Tensor of the default floating dtype, shape ``(N, C)``, on
            ``targets.device``. Row ``i`` holds ``1 - smoothing`` at column
            ``targets[i]`` and ``smoothing / (C - 1)`` elsewhere.

        Raises
        ------
        ValueError
            If ``smoothing`` is outside ``[0, 1)`` or ``n_classes < 2``.
        """
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing factor {smoothing} should be in the range [0, 1)")
        if n_classes < 2:
            # The off-target share is smoothing / (n_classes - 1).
            raise ValueError(f"label smoothing needs at least 2 classes, got {n_classes}")
        with torch.no_grad():
            smoothed = torch.empty(size=(targets.size(0), n_classes), device=targets.device).fill_(
                smoothing / (n_classes - 1)
            )
            # scatter_ needs an int64 index with the same rank as the
            # destination: (N,) -> (N, 1), one column written per row.
            smoothed.scatter_(1, targets.long().unsqueeze(1), 1.0 - smoothing)
        return smoothed

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed loss.

        Parameters
        ----------
        inputs: torch.Tensor
            Log-probabilities of shape ``(N, C)``; used directly, no
            ``log_softmax`` is applied here.
        targets: torch.Tensor
            Class indices of shape ``(N,)``.

        Returns
        -------
        torch.Tensor
            Per-sample losses of shape ``(N,)`` for ``reduction="none"``,
            otherwise a scalar.

        Raises
        ------
        ValueError
            If ``self.reduction`` is not one of `sum`, `none`, `mean`, or if
            ``self.smoothing`` / the class count is invalid.
        """
        smoothed_targets = LabelSmoothCrossEntropyLoss._smooth_one_hot(
            targets, inputs.size(-1), self.smoothing
        )
        log_probs = inputs
        if self.weight is not None:
            # (C,) -> (1, C) so the weights broadcast across the batch rows.
            log_probs = log_probs * self.weight.unsqueeze(0)
        loss = -(smoothed_targets * log_probs).sum(-1)

        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "none":
            return loss
        raise ValueError(
            f"{self.reduction} is not supported, please choose one of the following values"
            " [`sum`, `none`, `mean`]"
        )