classification

CLossHook

Bases: BaseWrapperHook

Computes a classification loss on the specified tensors. The default setting is to compute the cross entropy loss of the source domain logits.

Source code in pytorch_adapt/hooks/classification.py
class CLossHook(BaseWrapperHook):
    """
    Computes a classification loss on the specified tensors.
    The default setting is to compute the cross entropy loss
    of the source domain logits.
    """

    def __init__(
        self,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
        detach_features: bool = False,
        f_hook: BaseHook = None,
        **kwargs,
    ):
        """
        Arguments:
            loss_fn: The classification loss function. If ```None```,
                it defaults to ```torch.nn.CrossEntropyLoss```.
            detach_features: Whether or not to detach the features,
                from which logits are computed.
            f_hook: The hook for computing logits.
        """

        super().__init__(**kwargs)
        self.loss_fn = c_f.default(
            loss_fn, torch.nn.CrossEntropyLoss, {"reduction": "none"}
        )
        self.hook = c_f.default(
            f_hook,
            FeaturesAndLogitsHook,
            {"domains": ["src"], "detach_features": detach_features},
        )

    def call(self, inputs, losses):
        """"""
        outputs = self.hook(inputs, losses)[0]
        [src_logits] = c_f.extract(
            [outputs, inputs], c_f.filter(self.hook.out_keys, "_logits$")
        )
        loss = self.loss_fn(src_logits, inputs["src_labels"])
        return outputs, {self._loss_keys()[0]: loss}

    def _loss_keys(self):
        """"""
        return ["c_loss"]

__init__(loss_fn=None, detach_features=False, f_hook=None, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_fn` | `Callable[[torch.Tensor, torch.Tensor], torch.Tensor]` | The classification loss function. If `None`, it defaults to `torch.nn.CrossEntropyLoss`. | `None` |
| `detach_features` | `bool` | Whether or not to detach the features from which logits are computed. | `False` |
| `f_hook` | `BaseHook` | The hook for computing logits. | `None` |
Source code in pytorch_adapt/hooks/classification.py
def __init__(
    self,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    detach_features: bool = False,
    f_hook: BaseHook = None,
    **kwargs,
):
    """
    Arguments:
        loss_fn: The classification loss function. If ```None```,
            it defaults to ```torch.nn.CrossEntropyLoss```.
        detach_features: Whether or not to detach the features,
            from which logits are computed.
        f_hook: The hook for computing logits.
    """

    super().__init__(**kwargs)
    self.loss_fn = c_f.default(
        loss_fn, torch.nn.CrossEntropyLoss, {"reduction": "none"}
    )
    self.hook = c_f.default(
        f_hook,
        FeaturesAndLogitsHook,
        {"domains": ["src"], "detach_features": detach_features},
    )
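
A minimal usage sketch (the models, tensor shapes, and data below are placeholders; the `src_imgs`/`src_labels` key names and the `hook(inputs, losses)` calling convention follow the source above):

```python
import torch

from pytorch_adapt.hooks import CLossHook

# Hypothetical models: G extracts features, C turns features into logits.
G = torch.nn.Linear(32, 16)
C = torch.nn.Linear(16, 10)

hook = CLossHook()
inputs = {
    "G": G,
    "C": C,
    "src_imgs": torch.randn(8, 32),
    "src_labels": torch.randint(0, 10, (8,)),
}
# The wrapped FeaturesAndLogitsHook computes the source logits,
# then the classification loss is stored under "c_loss".
outputs, losses = hook(inputs, {})
print(losses["c_loss"].shape)  # per-sample losses, since reduction="none"
```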

ClassifierHook

Bases: BaseWrapperHook

This computes the classification loss and also optimizes the models.

Source code in pytorch_adapt/hooks/classification.py
class ClassifierHook(BaseWrapperHook):
    """
    This computes the classification loss and also
    optimizes the models.
    """

    def __init__(
        self,
        opts,
        weighter=None,
        reducer=None,
        loss_fn=None,
        f_hook=None,
        detach_features=False,
        pre=None,
        post=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        [pre, post] = c_f.many_default([pre, post], [[], []])
        hook = CLossHook(loss_fn, detach_features, f_hook)
        hook = ChainHook(*pre, hook, *post)
        hook = OptimizerHook(hook, opts, weighter, reducer)
        s_hook = SummaryHook({"total_loss": hook})
        self.hook = ChainHook(hook, s_hook)
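
A minimal training sketch (models, optimizers, and data are placeholders). Optimization happens inside the hook via `OptimizerHook`, and `SummaryHook` adds a `total_loss` entry to the returned losses:

```python
import torch

from pytorch_adapt.hooks import ClassifierHook

# Hypothetical models and optimizers.
G = torch.nn.Linear(32, 16)
C = torch.nn.Linear(16, 10)
opts = [
    torch.optim.SGD(G.parameters(), lr=0.01),
    torch.optim.SGD(C.parameters(), lr=0.01),
]

hook = ClassifierHook(opts)
batch = {"src_imgs": torch.randn(8, 32), "src_labels": torch.randint(0, 10, (8,))}
# One optimization step; the returned losses are for logging.
_, losses = hook({"G": G, "C": C, **batch}, {})
```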

FinetunerHook

Bases: ClassifierHook

This is the same as ClassifierHook, but it freezes the generator model ("G").

Source code in pytorch_adapt/hooks/classification.py
class FinetunerHook(ClassifierHook):
    """
    This is the same as
    [```ClassifierHook```][pytorch_adapt.hooks.ClassifierHook],
    but it freezes the generator model ("G").
    """

    def __init__(self, **kwargs):
        f_hook = FrozenModelHook(FeaturesHook(detach=True, domains=["src"]), "G")
        f_hook = FeaturesChainHook(f_hook, LogitsHook(domains=["src"]))
        super().__init__(f_hook=f_hook, **kwargs)
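
Usage mirrors `ClassifierHook`, except only the classifier's optimizer needs to be passed, since "G" is frozen (models and data below are placeholders):

```python
import torch

from pytorch_adapt.hooks import FinetunerHook

G = torch.nn.Linear(32, 16)  # frozen feature extractor
C = torch.nn.Linear(16, 10)  # trainable classifier

hook = FinetunerHook(opts=[torch.optim.SGD(C.parameters(), lr=0.01)])
batch = {"src_imgs": torch.randn(8, 32), "src_labels": torch.randint(0, 10, (8,))}
# Only C is updated; G's features are computed with detach=True.
_, losses = hook({"G": G, "C": C, **batch}, {})
```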

SoftmaxHook

Bases: ApplyFnHook

Applies torch.nn.Softmax(dim=1) to the specified inputs.

Source code in pytorch_adapt/hooks/classification.py
class SoftmaxHook(ApplyFnHook):
    """
    Applies ```torch.nn.Softmax(dim=1)``` to the
    specified inputs.
    """

    def __init__(self, **kwargs):
        super().__init__(fn=torch.nn.Softmax(dim=1), **kwargs)
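
The `apply_to` argument comes from the `ApplyFnHook` base class (the key name below is a placeholder). At the tensor level, the hook simply applies `torch.nn.Softmax(dim=1)`, normalizing each row of a `(batch, num_classes)` tensor into a probability distribution:

```python
import torch

from pytorch_adapt.hooks import SoftmaxHook

# Softmax will be applied to the tensors named here.
hook = SoftmaxHook(apply_to=["target_imgs_features_logits"])

# Equivalent tensor-level operation:
logits = torch.randn(8, 10)
probs = torch.nn.Softmax(dim=1)(logits)
assert torch.allclose(probs.sum(dim=1), torch.ones(8))
```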

SoftmaxLocallyHook

Bases: BaseWrapperHook

Applies torch.nn.Softmax(dim=1) to the specified inputs, which are overwritten, but only inside this hook.

Source code in pytorch_adapt/hooks/classification.py
class SoftmaxLocallyHook(BaseWrapperHook):
    """
    Applies ```torch.nn.Softmax(dim=1)``` to the
    specified inputs, which are overwritten, but
    only inside this hook.
    """

    def __init__(self, apply_to: List[str], *hooks: BaseHook, **kwargs):
        """
        Arguments:
            apply_to: list of names of tensors that softmax
                will be applied to.
            *hooks: the hooks that will receive the softmaxed
                tensors.
        """
        super().__init__(**kwargs)
        s_hook = SoftmaxHook(apply_to=apply_to)
        self.hook = OnlyNewOutputsHook(ChainHook(s_hook, *hooks, overwrite=True))

__init__(apply_to, *hooks, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `apply_to` | `List[str]` | List of names of tensors that softmax will be applied to. | *required* |
| `*hooks` | `BaseHook` | The hooks that will receive the softmaxed tensors. | `()` |
Source code in pytorch_adapt/hooks/classification.py
def __init__(self, apply_to: List[str], *hooks: BaseHook, **kwargs):
    """
    Arguments:
        apply_to: list of names of tensors that softmax
            will be applied to.
        *hooks: the hooks that will receive the softmaxed
            tensors.
    """
    super().__init__(**kwargs)
    s_hook = SoftmaxHook(apply_to=apply_to)
    self.hook = OnlyNewOutputsHook(ChainHook(s_hook, *hooks, overwrite=True))
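
A sketch with a hypothetical inner hook (assuming `BaseHook` subclasses implement `call`, `_loss_keys`, and `_out_keys`, matching the `call(inputs, losses)` convention shown above; `EntropyHook` and the key name are made up for illustration):

```python
import torch

from pytorch_adapt.hooks import BaseHook, SoftmaxLocallyHook

class EntropyHook(BaseHook):
    """Hypothetical hook: mean entropy of already-softmaxed logits."""

    def call(self, inputs, losses):
        probs = inputs["target_imgs_features_logits"]
        entropy = -(probs * torch.log(probs.clamp(min=1e-12))).sum(dim=1)
        return {}, {"entropy_loss": entropy.mean()}

    def _loss_keys(self):
        return ["entropy_loss"]

    def _out_keys(self):
        return []

hook = SoftmaxLocallyHook(["target_imgs_features_logits"], EntropyHook())
inputs = {"target_imgs_features_logits": torch.randn(8, 10)}
outputs, losses = hook(inputs, {})
# EntropyHook received probabilities, but the softmaxed tensor is not
# returned in outputs (OnlyNewOutputsHook), so the original logits
# remain untouched outside this hook.
```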