domain

DomainLossHook

Bases: BaseWrapperHook

Computes the loss of a discriminator's output with respect to domain labels.

Source code in pytorch_adapt/hooks/domain.py
class DomainLossHook(BaseWrapperHook):
    """
    Computes the loss of a discriminator's output with
    respect to domain labels.
    """

    def __init__(
        self,
        d_loss_fn=None,
        detach_features=False,
        reverse_labels=False,
        domains=None,
        f_hook=None,
        d_hook=None,
        **kwargs,
    ):
        """
        Arguments:
            d_loss_fn: The loss applied to the discriminator's logits.
                If ```None``` it defaults to
                ```torch.nn.BCEWithLogitsLoss```.
            detach_features: If ```True```, the input to the
                discriminator will be detached first.
            reverse_labels: If ```True```, the ```"src"``` and
                ```"target"``` domain labels will be swapped.
            domains: The domains to apply the loss to.
                If ```None``` it defaults to ```["src", "target"]```.
            f_hook: The hook for computing the input to the discriminator.
            d_hook: The hook for computing the discriminator logits.
        """
        super().__init__(**kwargs)
        self.d_loss_fn = c_f.default(
            d_loss_fn, torch.nn.BCEWithLogitsLoss, {"reduction": "none"}
        )
        self.reverse_labels = reverse_labels
        self.domains = c_f.default(domains, ["src", "target"])
        f_hook = c_f.default(
            f_hook,
            FeaturesForDomainLossHook,
            {"detach": detach_features, "domains": domains},
        )
        d_hook = c_f.default(d_hook, DLogitsHook, {"domains": domains})
        f_out = f_hook.last_hook_out_keys
        d_hook.set_in_keys(f_out)
        self.check_fhook_dhook_keys(f_hook, d_hook, detach_features)
        self.hook = ChainHook(f_hook, d_hook)
        self.in_keys = self.hook.in_keys + ["src_domain", "target_domain"]

    def call(self, inputs, losses):
        # Start from a fresh dict; the incoming losses are left untouched.
        losses = {}
        outputs = self.hook(inputs, losses)[0]
        domain_labels = self.extract_domain_labels(inputs)
        for domain_name, labels in domain_labels.items():
            self.logger(f"Computing loss for {domain_name} domain")
            [dlogits] = c_f.extract(
                [outputs, inputs],
                c_f.filter(self.hook.out_keys, "_dlogits$", [f"^{domain_name}"]),
            )
            if dlogits.dim() > 1:
                labels = labels.type(torch.long)
            else:
                labels = labels.type(torch.float)
            loss = self.d_loss_fn(dlogits, labels)
            losses[f"{domain_name}_domain_loss"] = loss
        return outputs, losses

    def extract_domain_labels(self, inputs):
        self.logger("Expecting 'src_domain' and 'target_domain' in inputs")
        [src_domain, target_domain] = c_f.extract(
            inputs, ["src_domain", "target_domain"]
        )
        if self.reverse_labels:
            labels = {"src": target_domain, "target": src_domain}
        else:
            labels = {"src": src_domain, "target": target_domain}
        return {k: v for k, v in labels.items() if k in self.domains}

    def _loss_keys(self):
        return [f"{x}_domain_loss" for x in self.domains]

    def check_fhook_dhook_keys(self, f_hook, d_hook, detach_features):
        if detach_features and len(
            c_f.filter(f_hook.out_keys, "detached$", self.domains)
        ) < len(self.domains):
            error_str = (
                "detach_features is True, but the number of f_hook's detached outputs "
            )
            error_str += "doesn't match the number of domains."
            error_str += f"\nf_hook's outputs: {f_hook.out_keys}"
            error_str += f"\nfdomains: {self.domains}"
            raise ValueError(error_str)
        for name, keys in [("f_hook", f_hook.out_keys), ("d_hook", d_hook.out_keys)]:
            if not all(
                c_f.filter(keys, f"^{self.domains[i]}")
                for i in range(len(self.domains))
            ):
                raise ValueError(
                    f"domains = {self.domains} but {name}.out_keys = {keys}"
                )

    def extra_repr(self):
        return c_f.extra_repr(self, ["reverse_labels"])
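
A minimal usage sketch follows. It is not taken from the library's docs: the `"G"`/`"D"` model keys and the `src_imgs`/`target_imgs` data keys follow the conventions of the default `FeaturesHook` and `DLogitsHook`, and the two-argument call mirrors the `call` signature above, so treat all of these as assumptions.

```python
import torch

from pytorch_adapt.hooks import DomainLossHook

G = torch.nn.Linear(100, 10)  # feature extractor, keyed "G"
# The default BCEWithLogitsLoss expects 1-D logits, so flatten D's output.
D = torch.nn.Sequential(torch.nn.Linear(10, 1), torch.nn.Flatten(start_dim=0))

hook = DomainLossHook()
N = 32
inputs = {
    "G": G,
    "D": D,
    "src_imgs": torch.randn(N, 100),
    "target_imgs": torch.randn(N, 100),
    # 0 = source, 1 = target is a common labeling convention
    "src_domain": torch.zeros(N),
    "target_domain": torch.ones(N),
}
outputs, losses = hook(inputs, {})
# losses contains "src_domain_loss" and "target_domain_loss", each a
# per-sample tensor of shape (N,) because the default loss uses
# reduction="none".
```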

__init__(d_loss_fn=None, detach_features=False, reverse_labels=False, domains=None, f_hook=None, d_hook=None, **kwargs)

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `d_loss_fn` | The loss applied to the discriminator's logits. If `None` it defaults to `torch.nn.BCEWithLogitsLoss`. | `None` |
| `detach_features` | If `True`, the input to the discriminator will be detached first. | `False` |
| `reverse_labels` | If `True`, the `"src"` and `"target"` domain labels will be swapped (see the sketch below). | `False` |
| `domains` | The domains to apply the loss to. If `None` it defaults to `["src", "target"]`. | `None` |
| `f_hook` | The hook for computing the input to the discriminator. | `None` |
| `d_hook` | The hook for computing the discriminator logits. | `None` |
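
Since `detach_features` and `reverse_labels` work together in adversarial setups, here is a hedged sketch of how they are typically paired: one hook trains the discriminator on detached features, and a second hook trains the feature extractor against swapped labels, so that minimizing its loss pushes features toward fooling `D`. This pairing illustrates common GAN-style usage; it is not code from the library.

```python
from pytorch_adapt.hooks import DomainLossHook

# Discriminator step: detached features, so gradients stop at D's input.
d_step = DomainLossHook(detach_features=True)
# Feature-extractor step: swapped domain labels, so G learns to fool D.
g_step = DomainLossHook(reverse_labels=True)
```
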
FeaturesForDomainLossHook

Bases: FeaturesChainHook

A FeaturesChainHook that has options specific to DomainLossHook.

Source code in pytorch_adapt/hooks/domain.py
class FeaturesForDomainLossHook(FeaturesChainHook):
    """
    A [```FeaturesChainHook```][pytorch_adapt.hooks.features.FeaturesChainHook]
    that has options specific to
    [```DomainLossHook```][pytorch_adapt.hooks.DomainLossHook].
    """

    def __init__(
        self,
        f_hook=None,
        l_hook=None,
        use_logits=False,
        domains=None,
        detach=False,
        **kwargs,
    ):
        """
        Arguments:
            f_hook: hook for computing features
            l_hook: hook for computing logits. This will be used
                only if ```use_logits``` is ```True```.
            use_logits: If ```True```, the logits hook is executed
                after the features hook.
            domains: the domains for which features will be computed.
            detach: If ```True```, all outputs will be detached
                from the autograd graph.
        """
        hooks = [
            c_f.default(
                f_hook,
                FeaturesHook(detach=detach, domains=domains),
            )
        ]
        if use_logits:
            hooks.append(
                c_f.default(l_hook, LogitsHook(detach=detach, domains=domains))
            )
        super().__init__(*hooks, **kwargs)

__init__(f_hook=None, l_hook=None, use_logits=False, domains=None, detach=False, **kwargs)

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `f_hook` | Hook for computing features. | `None` |
| `l_hook` | Hook for computing logits. This will be used only if `use_logits` is `True`. | `None` |
| `use_logits` | If `True`, the logits hook is executed after the features hook. | `False` |
| `domains` | The domains for which features will be computed. | `None` |
| `detach` | If `True`, all outputs will be detached from the autograd graph. | `False` |
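
To close, a sketch of combining the two classes: a custom `FeaturesForDomainLossHook` with `use_logits=True` feeds classifier logits, rather than raw features, to the discriminator. The `"C"` classifier key is an assumption about what the default `LogitsHook` reads, and `detach=True` is paired with `detach_features=True` so that the `"detached"` out_keys satisfy `check_fhook_dhook_keys`.

```python
from pytorch_adapt.hooks import DomainLossHook
from pytorch_adapt.hooks.domain import FeaturesForDomainLossHook

f_hook = FeaturesForDomainLossHook(use_logits=True, detach=True)
hook = DomainLossHook(f_hook=f_hook, detach_features=True)
# The inputs dict must now also contain the classifier, e.g.
# {"G": G, "C": C, "D": D, "src_imgs": ..., "target_imgs": ...,
#  "src_domain": ..., "target_domain": ...}
```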