Skip to content

ist_loss

ISTLoss

Bases: torch.nn.Module

Implementation of the $I_{st}$ loss from *Information-Theoretical Learning of Discriminative Clusters for Unsupervised Domain Adaptation* (ICML 2012).

Source code in `pytorch_adapt/layers/ist_loss.py` (lines 31–71):
class ISTLoss(torch.nn.Module):
    """
    Implementation of the I_st loss from
    [Information-Theoretical Learning of Discriminative Clusters for Unsupervised Domain Adaptation](https://icml.cc/2012/papers/566.pdf)

    Arguments:
        distance: module used to compare every pair of rows of the input
            features. When ``None``, ``c_f.default`` supplies a
            ``CosineSimilarity`` built with no extra kwargs.
        with_ent: flag forwarded to ``get_loss`` alongside an
            ``EntropyLoss``; presumably toggles the entropy term.
        with_div: flag forwarded to ``get_loss`` alongside a
            ``DiversityLoss``; presumably toggles the diversity term.

    Raises:
        ValueError: if both ``with_ent`` and ``with_div`` are False.
    """

    def __init__(self, distance=None, with_ent=True, with_div=True):
        super().__init__()
        # Pairwise comparison module; its ``is_inverted`` attribute is later
        # handed to get_probs.
        self.distance = c_f.default(distance, CosineSimilarity, {})
        if not (with_ent or with_div):
            raise ValueError("At least one of with_ent or with_div must be True")
        self.with_ent = with_ent
        self.with_div = with_div
        # Both loss helpers are configured with after_softmax=True, i.e. they
        # are given probabilities (the output of get_probs), not logits.
        self.ent_loss_fn = EntropyLoss(after_softmax=True)
        self.div_loss_fn = DiversityLoss(after_softmax=True)

    def forward(self, x, y):
        """
        Arguments:
            x: source and target features, stacked along the first dimension
                (N rows in total)
            y: domain labels, i.e. 0 for source domain, 1 for target domain

        Raises:
            TypeError: if ``y`` does not have shape ``(N,)`` where
                ``N = x.shape[0]``.
            ValueError: if any value of ``y`` lies outside ``[0, 1]``.
        """
        n = x.shape[0]
        # Validate the shape first: it is the cheaper check, and it reports a
        # clear TypeError for e.g. an empty y, on which torch.min/torch.max
        # would otherwise raise an opaque RuntimeError.
        if y.shape != torch.Size([n]):
            raise TypeError("y must have shape (N,)")
        if torch.min(y) < 0 or torch.max(y) > 1:
            raise ValueError("y must be in the range 0 and 1")

        mat = self.distance(x)
        # Remove self comparisons (the diagonal), leaving an (n, n - 1)
        # matrix. The mask is created on mat's device so that indexing a
        # GPU-resident mat does not require a host-side boolean mask.
        mask = ~torch.eye(n, dtype=torch.bool, device=mat.device)
        mat = mat[mask].view(n, n - 1)
        probs = get_probs(mat, mask, y, self.distance.is_inverted)

        return get_loss(
            probs, self.ent_loss_fn, self.div_loss_fn, self.with_ent, self.with_div
        )

    def extra_repr(self):
        """"""
        # The docstring is intentionally empty (keeps it out of the rendered
        # docs). Report both constructor flags, not just with_div.
        return c_f.extra_repr(self, ["with_ent", "with_div"])

forward(x, y)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | | source and target features | required |
| `y` | | domain labels, i.e. 0 for source domain, 1 for target domain | required |
Source code in `pytorch_adapt/layers/ist_loss.py` (lines 47–67):
def forward(self, x, y):
    """
    Arguments:
        x: source and target features, stacked along the first dimension
            (N rows in total)
        y: domain labels, i.e. 0 for source domain, 1 for target domain

    Raises:
        TypeError: if ``y`` does not have shape ``(N,)`` where
            ``N = x.shape[0]``.
        ValueError: if any value of ``y`` lies outside ``[0, 1]``.
    """
    n = x.shape[0]
    # Validate the shape first: it is the cheaper check, and it reports a
    # clear TypeError for e.g. an empty y, on which torch.min/torch.max
    # would otherwise raise an opaque RuntimeError.
    if y.shape != torch.Size([n]):
        raise TypeError("y must have shape (N,)")
    if torch.min(y) < 0 or torch.max(y) > 1:
        raise ValueError("y must be in the range 0 and 1")

    mat = self.distance(x)
    # Remove self comparisons (the diagonal), leaving an (n, n - 1) matrix.
    # The mask is created on mat's device so that indexing a GPU-resident
    # mat does not require a host-side boolean mask.
    mask = ~torch.eye(n, dtype=torch.bool, device=mat.device)
    mat = mat[mask].view(n, n - 1)
    probs = get_probs(mat, mask, y, self.distance.is_inverted)

    return get_loss(
        probs, self.ent_loss_fn, self.div_loss_fn, self.with_ent, self.with_div
    )