
bnm_loss

BNMLoss

Bases: torch.nn.Module

Implementation of the loss in Towards Discriminability and Diversity: Batch Nuclear-norm Maximization under Label Insufficient Situations.
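For reference, here is the quantity that the forward pass in the source listing computes, written as an equation. The notation is assumed for this illustration: $X$ is the $N \times C$ batch of logits ($N$ samples, $C$ classes), $A$ is its row-wise softmax, and $\sigma_i(A)$ are the singular values of $A$.

$$
A = \operatorname{softmax}(X)\ \text{(applied row-wise)}, \qquad
\mathcal{L}_{\text{BNM}} = -\frac{1}{N}\,\|A\|_* = -\frac{1}{N}\sum_{i=1}^{\min(N,\,C)} \sigma_i(A)
$$

Minimizing $\mathcal{L}_{\text{BNM}}$ is therefore the same as maximizing the batch nuclear norm $\|A\|_*$.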

Source code in pytorch_adapt\layers\bnm_loss.py
import torch


class BNMLoss(torch.nn.Module):
    """
    Implementation of the loss in
    [Towards Discriminability and Diversity:
    Batch Nuclear-norm Maximization
    under Label Insufficient Situations](https://arxiv.org/abs/2003.12237).
    """

    def forward(self, x):
        """"""
        x = torch.nn.functional.softmax(x, dim=1)
        return -torch.linalg.norm(x, "nuc") / x.shape[0]
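A minimal usage sketch, assuming PyTorch 1.9 or later (needed for `torch.linalg.svdvals`). It repeats the class so the snippet runs on its own, compares three hypothetical 4×4 logit batches, and cross-checks the `"nuc"` norm against an explicit sum of singular values. The helper `nuclear_norm_via_svd` and the example batches are introduced only for this illustration.

```python
import torch


class BNMLoss(torch.nn.Module):
    # Same computation as the source listing: negative nuclear norm of the
    # row-wise softmax, divided by the batch size.
    def forward(self, x):
        x = torch.nn.functional.softmax(x, dim=1)
        return -torch.linalg.norm(x, "nuc") / x.shape[0]


def nuclear_norm_via_svd(a):
    # Hypothetical helper: the nuclear norm equals the sum of singular values.
    return torch.linalg.svdvals(a).sum()


loss_fn = BNMLoss()

# Hypothetical logit batches with N=4 samples and C=4 classes.
diverse = 10.0 * torch.eye(4)   # each sample confidently predicts a different class
collapsed = torch.zeros(4, 4)
collapsed[:, 0] = 10.0          # every sample confidently predicts class 0
uniform = torch.zeros(4, 4)     # every sample is maximally uncertain

loss_diverse = loss_fn(diverse)
loss_collapsed = loss_fn(collapsed)
loss_uniform = loss_fn(uniform)

# Cross-check torch.linalg.norm(..., "nuc") against an explicit SVD.
probs = torch.nn.functional.softmax(diverse, dim=1)
manual = -nuclear_norm_via_svd(probs) / probs.shape[0]

print(loss_diverse.item(), loss_collapsed.item(), loss_uniform.item())
```

Confident predictions spread across different classes give the lowest (most negative) loss. Collapsed or uniformly uncertain predictions produce a rank-1 probability matrix with a smaller nuclear norm, so their loss is higher, which is consistent with the discriminability and diversity motivation in the paper's title.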