Skip to content

mcc_loss

MCCLoss

Bases: torch.nn.Module

Implementation of Minimum Class Confusion for Versatile Domain Adaptation.

Source code in `pytorch_adapt/layers/mcc_loss.py` (lines 11–52):
class MCCLoss(torch.nn.Module):
    """
    Implementation of
    [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/abs/1912.03699).

    The loss is the sum of the off-diagonal entries of a per-sample
    weighted, normalized class confusion matrix computed from the softmax
    of the target logits, divided by the number of classes.
    """

    def __init__(
        self,
        T: float = 1,
        entropy_weighter: Callable[[torch.Tensor], torch.Tensor] = None,
    ):
        """
        Arguments:
            T: temperature that the target logits are divided by before
                the softmax is taken.
            entropy_weighter: callable that receives the detached softmax
                predictions and returns one weight per sample. The weights
                scale each sample's contribution to the class confusion
                matrix, as described in the paper.
                If ```None```, then ```layers.EntropyWeights``` is used.
        """
        super().__init__()
        self.T = T
        # The default weighter is built unconditionally, even when a custom
        # weighter is supplied; c_f.default then picks which one to keep.
        fallback_weighter = EntropyWeights(
            after_softmax=True,
            normalizer=SumNormalizer(scale_by_batch_size=True),
        )
        self.entropy_weighter = c_f.default(entropy_weighter, fallback_weighter)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: target logits, 2-D with one row per sample (softmax is taken
                over ``dim=1``).

        Returns:
            A 0-dim tensor: sum of the off-diagonal entries of the
            normalized class confusion matrix divided by the number of
            classes.
        """
        probs = x.div(self.T).softmax(dim=1)
        # Weights come from detached probabilities, so no gradient flows
        # through the weighting itself.
        sample_weights = self.entropy_weighter(probs.detach()).view(-1, 1)
        # (classes, batch) @ (batch, classes) -> (classes, classes)
        confusion = (probs * sample_weights).t() @ probs
        # A 1-D divisor broadcasts along the last axis: entry (i, j) is
        # divided by the sum of row j. The matrix is P^T diag(w) P, which is
        # symmetric (up to rounding), so this is the transpose of the paper's
        # row normalization. The off-diagonal sum below is transpose-invariant.
        confusion = confusion / confusion.sum(dim=1)
        num_classes = confusion.shape[0]
        return (confusion.sum() - confusion.trace()) / num_classes

    def extra_repr(self):
        """"""
        # Empty docstring kept as in the original (presumably so the method
        # stays out of the generated API docs). Only "T" is handed to the
        # shared repr helper.
        return c_f.extra_repr(self, ["T"])

__init__(T=1, entropy_weighter=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `T` | `float` | Softmax temperature applied to the input target logits. | `1` |
| `entropy_weighter` | `Callable[[torch.Tensor], torch.Tensor]` | A function that returns a weight for each sample. The weights are used in the process of computing the class confusion tensor as described in the paper. If `None`, then `layers.EntropyWeights` is used. | `None` |
Source code in `pytorch_adapt/layers/mcc_loss.py` (lines 17–37):
def __init__(
    self,
    T: float = 1,
    entropy_weighter: Callable[[torch.Tensor], torch.Tensor] = None,
):
    """
    Arguments:
        T: temperature that the target logits are divided by before
            the softmax is taken.
        entropy_weighter: callable that receives the detached softmax
            predictions and returns one weight per sample. The weights
            scale each sample's contribution to the class confusion
            matrix, as described in the paper.
            If ```None```, then ```layers.EntropyWeights``` is used.
    """
    super().__init__()
    self.T = T
    # The default weighter is built unconditionally, even when a custom
    # weighter is supplied; c_f.default then picks which one to keep.
    fallback_weighter = EntropyWeights(
        after_softmax=True,
        normalizer=SumNormalizer(scale_by_batch_size=True),
    )
    self.entropy_weighter = c_f.default(entropy_weighter, fallback_weighter)

forward(x)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `torch.Tensor` | Target logits. | *required* |
Source code in `pytorch_adapt/layers/mcc_loss.py` (lines 39–48):
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Arguments:
        x: target logits, 2-D with one row per sample (softmax is taken
            over ``dim=1``).

    Returns:
        A 0-dim tensor: sum of the off-diagonal entries of the normalized
        class confusion matrix divided by the number of classes.
    """
    probs = x.div(self.T).softmax(dim=1)
    # Weights come from detached probabilities, so no gradient flows
    # through the weighting itself.
    sample_weights = self.entropy_weighter(probs.detach()).view(-1, 1)
    # (classes, batch) @ (batch, classes) -> (classes, classes)
    confusion = (probs * sample_weights).t() @ probs
    # A 1-D divisor broadcasts along the last axis: entry (i, j) is divided
    # by the sum of row j. The matrix is P^T diag(w) P, which is symmetric
    # (up to rounding), so this is the transpose of the paper's row
    # normalization. The off-diagonal sum below is transpose-invariant.
    confusion = confusion / confusion.sum(dim=1)
    num_classes = confusion.shape[0]
    return (confusion.sum() - confusion.trace()) / num_classes