Bases: `torch.nn.Module`

Implementation of [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/abs/1912.03699).
Source code in pytorch_adapt\layers\mcc_loss.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class MCCLoss(torch.nn.Module):
    """
    Implementation of
    [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/abs/1912.03699).
    """

    def __init__(
        self,
        T: float = 1,
        entropy_weighter: Callable[[torch.Tensor], torch.Tensor] = None,
    ):
        """
        Arguments:
            T: softmax temperature applied to the input target logits
            entropy_weighter: a function that returns a weight for each
                sample. The weights are used in the process of computing
                the class confusion tensor as described in the paper.
                If ```None```, then ```layers.EntropyWeights``` is used.
        """
        super().__init__()
        self.T = T
        self.entropy_weighter = c_f.default(
            entropy_weighter,
            EntropyWeights(
                after_softmax=True, normalizer=SumNormalizer(scale_by_batch_size=True)
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: target logits. Must be 2D, because the softmax is taken over
                ``dim=1`` and the probabilities are fed to ``torch.mm``
                (presumably shaped ``(batch_size, num_classes)``).
        Returns:
            A scalar tensor: the sum of the off-diagonal entries of the
            row-normalized class confusion matrix, divided by the number
            of classes.
        """
        Y = torch.nn.functional.softmax(x / self.T, dim=1)
        # The weighter sees detached probabilities, so no gradient flows
        # back into Y through the weights.
        H_weights = self.entropy_weighter(Y.detach())
        # C[j, k] = sum_n w_n * Y[n, j] * Y[n, k]  -> (num_classes, num_classes)
        C = torch.mm((Y * H_weights.view(-1, 1)).t(), Y)
        # Row-normalize as in the paper. keepdim=True makes the division
        # explicitly row-wise instead of relying on broadcasting plus the
        # symmetry of C.
        row_sums = torch.sum(C, dim=1, keepdim=True)
        # A class with exactly zero (weighted) probability mass in the batch
        # (e.g. softmax underflow at low T or in half precision) has an
        # all-zero row; 0 / 0 would make the loss and its gradient NaN.
        # Such rows are divided by 1 instead and therefore stay zero.
        row_sums = torch.where(row_sums != 0, row_sums, torch.ones_like(row_sums))
        C = C / row_sums
        return (torch.sum(C) - torch.trace(C)) / C.shape[0]

    def extra_repr(self):
        """"""
        # Empty docstring kept as in the original; presumably it keeps the
        # inherited torch.nn.Module docstring out of the generated docs —
        # confirm before changing. Delegates to c_f.extra_repr with ["T"].
        return c_f.extra_repr(self, ["T"])
|
`__init__(T=1, entropy_weighter=None)`

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `T` | `float` | softmax temperature applied to the input target logits | `1` |
| `entropy_weighter` | `Callable[[torch.Tensor], torch.Tensor]` | a function that returns a weight for each sample. The weights are used in the process of computing the class confusion tensor as described in the paper. If `None`, then `layers.EntropyWeights` is used. | `None` |
Source code in pytorch_adapt\layers\mcc_loss.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def __init__(
    self,
    T: float = 1,
    entropy_weighter: Callable[[torch.Tensor], torch.Tensor] = None,
):
    """
    Arguments:
        T: temperature by which the input target logits are divided
            before the softmax is applied.
        entropy_weighter: callable producing one weight per sample; the
            weights scale each sample's contribution to the class
            confusion tensor, following the paper. When ```None```, a
            ```layers.EntropyWeights``` instance (``after_softmax=True``,
            sum-normalized and scaled by batch size) is used.
    """
    super().__init__()
    self.T = T
    # Built unconditionally, exactly like the inline argument it replaces.
    fallback_weighter = EntropyWeights(
        after_softmax=True,
        normalizer=SumNormalizer(scale_by_batch_size=True),
    )
    self.entropy_weighter = c_f.default(entropy_weighter, fallback_weighter)
|
`forward(x)`

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `torch.Tensor` | target logits | *required* |
Source code in pytorch_adapt\layers\mcc_loss.py
39
40
41
42
43
44
45
46
47
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Arguments:
        x: target logits. Must be 2D, because the softmax is taken over
            ``dim=1`` and the probabilities are fed to ``torch.mm``
            (presumably shaped ``(batch_size, num_classes)``).
    Returns:
        A scalar tensor: the sum of the off-diagonal entries of the
        row-normalized class confusion matrix, divided by the number
        of classes.
    """
    Y = torch.nn.functional.softmax(x / self.T, dim=1)
    # The weighter sees detached probabilities, so no gradient flows
    # back into Y through the weights.
    H_weights = self.entropy_weighter(Y.detach())
    # C[j, k] = sum_n w_n * Y[n, j] * Y[n, k]  -> (num_classes, num_classes)
    C = torch.mm((Y * H_weights.view(-1, 1)).t(), Y)
    # Row-normalize as in the paper. keepdim=True makes the division
    # explicitly row-wise instead of relying on broadcasting plus the
    # symmetry of C.
    row_sums = torch.sum(C, dim=1, keepdim=True)
    # A class with exactly zero (weighted) probability mass in the batch
    # (e.g. softmax underflow at low T or in half precision) has an
    # all-zero row; 0 / 0 would make the loss and its gradient NaN.
    # Such rows are divided by 1 instead and therefore stay zero.
    row_sums = torch.where(row_sums != 0, row_sums, torch.ones_like(row_sums))
    C = C / row_sums
    return (torch.sum(C) - torch.trace(C)) / C.shape[0]
|