Bases: `torch.nn.Module`

Implementation of the $I_{st}$ loss from [Information-Theoretical Learning of Discriminative Clusters for Unsupervised Domain Adaptation](https://icml.cc/2012/papers/566.pdf).
Source code in `pytorch_adapt/layers/ist_loss.py` (lines 31–71):
class ISTLoss(torch.nn.Module):
    """
    Implementation of the I_st loss from
    [Information-Theoretical Learning of Discriminative Clusters for Unsupervised Domain Adaptation](https://icml.cc/2012/papers/566.pdf)

    Arguments:
        distance: module that maps the features ``x`` to a pairwise
            matrix. If not given, a ``CosineSimilarity`` is used
            (via ``c_f.default``). It must expose an ``is_inverted``
            attribute, which is forwarded to ``get_probs``.
        with_ent: flag forwarded to ``get_loss`` together with the entropy
            loss (presumably toggles the entropy term).
        with_div: flag forwarded to ``get_loss`` together with the diversity
            loss (presumably toggles the diversity term).

    Raises:
        ValueError: if both ``with_ent`` and ``with_div`` are False.
    """

    def __init__(self, distance=None, with_ent=True, with_div=True):
        super().__init__()
        self.distance = c_f.default(distance, CosineSimilarity, {})
        if not (with_ent or with_div):
            raise ValueError("At least one of with_ent or with_div must be True")
        self.with_ent = with_ent
        self.with_div = with_div
        # after_softmax=True: the inputs come from get_probs and are
        # presumably already normalized probabilities.
        self.ent_loss_fn = EntropyLoss(after_softmax=True)
        self.div_loss_fn = DiversityLoss(after_softmax=True)

    def forward(self, x, y):
        """
        Arguments:
            x: source and target features; ``x.shape[0]`` is the batch size N.
            y: domain labels, i.e. 0 for source domain, 1 for target domain.
                Must have shape (N,).

        Returns:
            The value produced by ``get_loss`` from the pairwise probabilities.

        Raises:
            TypeError: if ``y`` does not have shape (N,).
            ValueError: if any value of ``y`` is outside [0, 1].
        """
        n = x.shape[0]
        # Validate the shape first: it is cheap, and it ensures the range
        # check below never reduces over a mis-sized (e.g. empty) tensor,
        # which would raise an unrelated RuntimeError.
        if y.shape != torch.Size([n]):
            raise TypeError("y must have shape (N,)")
        # One elementwise reduction instead of separate min() and max() calls.
        if torch.any((y < 0) | (y > 1)):
            raise ValueError("y must be in the range 0 and 1")
        mat = self.distance(x)
        # Remove self comparisons (the diagonal), leaving an (N, N - 1) matrix.
        # The mask lives on mat's device so indexing needs no cross-device copy.
        mask = ~torch.eye(n, dtype=torch.bool, device=mat.device)
        mat = mat[mask].view(n, n - 1)
        probs = get_probs(mat, mask, y, self.distance.is_inverted)
        return get_loss(
            probs, self.ent_loss_fn, self.div_loss_fn, self.with_ent, self.with_div
        )

    def extra_repr(self):
        """"""
        # Empty docstring kept as in the original (presumably to keep this
        # method out of the generated API docs).
        # Report both loss-term flags, not only with_div.
        return c_f.extra_repr(self, ["with_ent", "with_div"])
### `forward(x, y)`

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | | source and target features | required |
| `y` | | domain labels, i.e. 0 for source domain, 1 for target domain | required |
Source code in `pytorch_adapt/layers/ist_loss.py` (lines 47–67):
def forward(self, x, y):
    """
    Arguments:
        x: source and target features; ``x.shape[0]`` is the batch size N.
        y: domain labels, i.e. 0 for source domain, 1 for target domain.
            Must have shape (N,).

    Returns:
        The value produced by ``get_loss`` from the pairwise probabilities.

    Raises:
        TypeError: if ``y`` does not have shape (N,).
        ValueError: if any value of ``y`` is outside [0, 1].
    """
    n = x.shape[0]
    # Validate the shape first: it is cheap, and it ensures the range
    # check below never reduces over a mis-sized (e.g. empty) tensor,
    # which would raise an unrelated RuntimeError.
    if y.shape != torch.Size([n]):
        raise TypeError("y must have shape (N,)")
    # One elementwise reduction instead of separate min() and max() calls.
    if torch.any((y < 0) | (y > 1)):
        raise ValueError("y must be in the range 0 and 1")
    mat = self.distance(x)
    # Remove self comparisons (the diagonal), leaving an (N, N - 1) matrix.
    # The mask lives on mat's device so indexing needs no cross-device copy.
    mask = ~torch.eye(n, dtype=torch.bool, device=mat.device)
    mat = mat[mask].view(n, n - 1)
    probs = get_probs(mat, mask, y, self.distance.is_inverted)
    return get_loss(
        probs, self.ent_loss_fn, self.div_loss_fn, self.with_ent, self.with_div
    )
|