Skip to content

nll_loss

NLLLoss

Bases: torch.nn.Module

Same as torch.nn.NLLLoss, but takes softmax probabilities as input instead of log-probabilities (the log is applied internally).

Source code in pytorch_adapt\layers\nll_loss.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
class NLLLoss(torch.nn.Module):
    """
    Negative log-likelihood loss for inputs that are already probabilities.

    Behaves like torch.nn.NLLLoss, except that the logarithm is taken
    inside this module, so the caller passes softmax output rather than
    log-softmax output.
    """

    def __init__(self, reduction="mean"):
        """
        Arguments:
            reduction: Forwarded unchanged as the ``reduction`` argument of
                ``torch.nn.functional.nll_loss`` on every forward pass.
        """
        super().__init__()
        self.reduction = reduction

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        Arguments:
            x: Probabilities (e.g. the output of a softmax).
            y: Targets, passed straight through to
                ``torch.nn.functional.nll_loss``.

        Returns:
            The result of ``torch.nn.functional.nll_loss`` applied to the
            log of the (offset) input, using ``self.reduction``.
        """
        # Offset the input before taking the log; presumably small_val is a
        # tiny dtype-dependent constant that keeps log(0) finite -- confirm
        # against pml_cf.
        offset = pml_cf.small_val(x.dtype)
        log_probs = (x + offset).log()
        return torch.nn.functional.nll_loss(
            log_probs, y, reduction=self.reduction
        )