diversity_loss

DiversityLoss

Bases: torch.nn.Module

Encourages predictions to be uniform, batch-wise: the loss is the negative entropy of the batch-averaged softmax distribution, so it is minimized when the average prediction over the batch is uniform. Takes logits (before softmax) as input.

For example:

  - A tensor with a large loss: `torch.tensor([[1e4, 0, 0], [1e4, 0, 0], [1e4, 0, 0]])`

  - A tensor with a small loss: `torch.tensor([[1e4, 0, 0], [0, 1e4, 0], [0, 0, 1e4]])`
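
A minimal sketch running both tensors above through the loss (the exact printed values depend on any numerical guards inside the library's entropy helper, but the ordering holds):

import torch

from pytorch_adapt.layers import DiversityLoss

loss_fn = DiversityLoss()

# Every row predicts class 0, so the batch-averaged distribution is one-hot:
# its entropy is ~0, and the loss is near its maximum value of 0.
collapsed = torch.tensor([[1e4, 0, 0], [1e4, 0, 0], [1e4, 0, 0]])
print(loss_fn(collapsed))  # close to 0

# Each class is predicted exactly once, so the batch-averaged distribution is
# uniform: its entropy is maximal, and the loss is near -log(3) ≈ -1.0986.
diverse = torch.tensor([[1e4, 0, 0], [0, 1e4, 0], [0, 0, 1e4]])
print(loss_fn(diverse))  # close to -log(3)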

Source code in pytorch_adapt/layers/diversity_loss.py
# Imports reconstructed so the excerpt is runnable (the excerpt begins at
# module line 7, after the import block).
import torch

from ..utils import common_functions as c_f
from .entropy_loss import get_entropy


class DiversityLoss(torch.nn.Module):
    """
    Encourages predictions to be uniform, batch-wise.
    Takes logits (before softmax) as input.

    For example:

    - A tensor with a large loss: ```torch.tensor([[1e4, 0, 0], [1e4, 0, 0], [1e4, 0, 0]])```

    - A tensor with a small loss: ```torch.tensor([[1e4, 0, 0], [0, 1e4, 0], [0, 0, 1e4]])```
    """

    def __init__(self, after_softmax: bool = False):
        """
        Arguments:
            after_softmax: If ```True```, then the rows of the input are assumed to
                already have softmax applied to them.
        """
        super().__init__()
        self.after_softmax = after_softmax

    def forward(self, logits):
        """"""
        if not self.after_softmax:
            logits = torch.softmax(logits, dim=1)
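        # Average the per-sample distributions over the batch, then return
        # the negative entropy of that averaged distribution.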
        logits = torch.mean(logits, dim=0, keepdim=True)
        return -torch.mean(get_entropy(logits, after_softmax=True))

    def extra_repr(self):
        """"""
        return c_f.extra_repr(self, ["after_softmax"])
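
For reference, the forward pass is equivalent to the following self-contained sketch. It assumes `get_entropy` computes the standard Shannon entropy; the `clamp_min` guard against log(0) is part of this sketch, not necessarily of the library helper.

import torch


def diversity_loss(logits: torch.Tensor) -> torch.Tensor:
    # Per-sample probability distributions.
    probs = torch.softmax(logits, dim=1)
    # Batch-averaged distribution, shape (1, num_classes).
    mean_probs = probs.mean(dim=0, keepdim=True)
    # Shannon entropy of the averaged distribution; clamp_min avoids log(0).
    entropy = -(mean_probs * mean_probs.clamp_min(1e-12).log()).sum(dim=1)
    # Negative entropy: minimizing this pushes the batch average toward uniform.
    return -entropy.mean()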

__init__(after_softmax=False)

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `after_softmax` | `bool` | If `True`, then the rows of the input are assumed to already have softmax applied to them. | `False` |
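
Setting `after_softmax=True` skips the internal softmax, which is useful when a model already outputs probabilities. A brief sketch, where `model_output` is a hypothetical stand-in for a model's logits:

import torch

from pytorch_adapt.layers import DiversityLoss

loss_fn = DiversityLoss(after_softmax=True)

# model_output is a placeholder; substitute your model's actual output.
model_output = torch.randn(8, 3)
probs = torch.softmax(model_output, dim=1)

# No second softmax is applied internally because after_softmax=True.
loss = loss_fn(probs)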
Source code in pytorch_adapt/layers/diversity_loss.py
def __init__(self, after_softmax: bool = False):
    """
    Arguments:
        after_softmax: If ```True```, then the rows of the input are assumed to
            already have softmax applied to them.
    """
    super().__init__()
    self.after_softmax = after_softmax