Skip to content

dann

DANNHook

Bases: BaseWrapperHook

Implementation of Domain-Adversarial Training of Neural Networks.

This includes the model optimization step.

Source code in pytorch_adapt\hooks\dann.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
class DANNHook(BaseWrapperHook):
    """
    Implementation of
    [Domain-Adversarial Training of Neural Networks](https://arxiv.org/abs/1505.07818).

    This includes the model optimization step.
    """

    def __init__(
        self,
        opts,
        weighter=None,
        reducer=None,
        pre=None,
        pre_d=None,
        post_d=None,
        pre_g=None,
        post_g=None,
        gradient_reversal=None,
        gradient_reversal_weight=1,
        use_logits=False,
        f_hook=None,
        d_hook=None,
        c_hook=None,
        domain_loss_hook=None,
        d_hook_allowed="_dlogits$",
        **kwargs
    ):
        """
        Arguments:
            opts: List of optimizers for updating the models.
            weighter: Weights the losses before backpropagation.
                If ```None``` then it defaults to
                [```MeanWeighter```][pytorch_adapt.weighters.MeanWeighter]
            reducer: Reduces loss tensors.
                If ```None``` then it defaults to
                [```MeanReducer```][pytorch_adapt.hooks.MeanReducer]
            pre: List of hooks that will be executed at the very
                beginning of each iteration.
            pre_d: List of hooks that will be executed after
                gradient reversal, but before the domain loss.
            post_d: List of hooks that will be executed after
                gradient reversal, and after the domain loss.
            pre_g: List of hooks that will be executed outside of
                the gradient reversal step, and before the generator
                and classifier loss.
            post_g: List of hooks that will be executed after
                the generator and classifier losses.
            gradient_reversal: Called before all D hooks, including
                ```pre_d```.
            gradient_reversal_weight: The ```weight``` given to the default
                gradient reversal hook. Unused if ```gradient_reversal```
                is provided.
            use_logits: If ```True```, then D receives the output of C
                instead of the output of G.
            f_hook: The hook used for computing features and logits.
                If ```None``` then it defaults to
                [```FeaturesForDomainLossHook```][pytorch_adapt.hooks.FeaturesForDomainLossHook]
            d_hook: The hook used for computing discriminator logits.
                If ```None``` then it defaults to
                [```DLogitsHook```][pytorch_adapt.hooks.DLogitsHook]
            c_hook: The hook used for computing the classifier's loss.
                If ```None``` then it defaults to
                [```CLossHook```][pytorch_adapt.hooks.CLossHook]
            domain_loss_hook: The hook used for computing the domain loss.
                If ```None``` then it defaults to
                [```DomainLossHook```][pytorch_adapt.hooks.DomainLossHook].
            d_hook_allowed: A regex string that specifies the allowed
                output names of the discriminator block.
        """
        super().__init__(**kwargs)

        # Any hook-list argument left as None becomes an empty list.
        pre, pre_d, post_d, pre_g, post_g = c_f.many_default(
            [pre, pre_d, post_d, pre_g, post_g], [[], [], [], [], []]
        )
        f_hook = c_f.default(
            f_hook, FeaturesForDomainLossHook, {"use_logits": use_logits}
        )
        c_hook = c_f.default(c_hook, CLossHook, {})
        # The reversal is applied to the feature hook's outputs, so D's
        # gradients flow back to G negated (the core DANN mechanism).
        gradient_reversal = c_f.default(
            gradient_reversal,
            GradientReversalHook,
            {"weight": gradient_reversal_weight, "apply_to": f_hook.out_keys},
        )
        domain_loss_hook = c_f.default(
            domain_loss_hook, DomainLossHook, {"f_hook": f_hook, "d_hook": d_hook}
        )

        # Discriminator branch: reversal -> pre_d hooks -> domain loss -> post_d.
        # AssertHook enforces that this branch's outputs match d_hook_allowed.
        adversarial_branch = AssertHook(
            OnlyNewOutputsHook(
                ChainHook(
                    gradient_reversal,
                    *pre_d,
                    domain_loss_hook,
                    *post_d,
                    overwrite=[1],
                )
            ),
            d_hook_allowed,
        )

        # Generator/classifier branch runs outside the gradient reversal.
        task_branch = ChainHook(*pre_g, c_hook, *post_g)

        # One full iteration: features, both branches, optimizer step,
        # then a summary of the total loss.
        forward_pass = ChainHook(*pre, f_hook, adversarial_branch, task_branch)
        optimization = OptimizerHook(forward_pass, opts, weighter, reducer)
        summary = SummaryHook({"total_loss": optimization})
        self.hook = ChainHook(optimization, summary)

__init__(opts, weighter=None, reducer=None, pre=None, pre_d=None, post_d=None, pre_g=None, post_g=None, gradient_reversal=None, gradient_reversal_weight=1, use_logits=False, f_hook=None, d_hook=None, c_hook=None, domain_loss_hook=None, d_hook_allowed='_dlogits$', **kwargs)

Parameters:

Name Type Description Default
opts

List of optimizers for updating the models.

required
weighter

Weights the losses before backpropagation. If None then it defaults to MeanWeighter

None
reducer

Reduces loss tensors. If None then it defaults to MeanReducer

None
pre

List of hooks that will be executed at the very beginning of each iteration.

None
pre_d

List of hooks that will be executed after gradient reversal, but before the domain loss.

None
post_d

List of hooks that will be executed after gradient reversal, and after the domain loss.

None
pre_g

List of hooks that will be executed outside of the gradient reversal step, and before the generator and classifier loss.

None
post_g

List of hooks that will be executed after the generator and classifier losses.

None
gradient_reversal

Called before all D hooks, including pre_d.

None
gradient_reversal_weight

The weight given to the default gradient reversal hook. Unused if gradient_reversal is provided.

1
use_logits

If True, then D receives the output of C instead of the output of G.

False
f_hook

The hook used for computing features and logits. If None then it defaults to FeaturesForDomainLossHook

None
d_hook

The hook used for computing discriminator logits. If None then it defaults to DLogitsHook

None
c_hook

The hook used for computing the classifier's loss. If None then it defaults to CLossHook

None
domain_loss_hook

The hook used for computing the domain loss. If None then it defaults to DomainLossHook.

None
d_hook_allowed

A regex string that specifies the allowed output names of the discriminator block.

'_dlogits$'
Source code in pytorch_adapt\hooks\dann.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def __init__(
    self,
    opts,
    weighter=None,
    reducer=None,
    pre=None,
    pre_d=None,
    post_d=None,
    pre_g=None,
    post_g=None,
    gradient_reversal=None,
    gradient_reversal_weight=1,
    use_logits=False,
    f_hook=None,
    d_hook=None,
    c_hook=None,
    domain_loss_hook=None,
    d_hook_allowed="_dlogits$",
    **kwargs
):
    """
    Arguments:
        opts: List of optimizers for updating the models.
        weighter: Weights the losses before backpropagation.
            If ```None``` then it defaults to
            [```MeanWeighter```][pytorch_adapt.weighters.MeanWeighter]
        reducer: Reduces loss tensors.
            If ```None``` then it defaults to
            [```MeanReducer```][pytorch_adapt.hooks.MeanReducer]
        pre: List of hooks that will be executed at the very
            beginning of each iteration.
        pre_d: List of hooks that will be executed after
            gradient reversal, but before the domain loss.
        post_d: List of hooks that will be executed after
            gradient reversal, and after the domain loss.
        pre_g: List of hooks that will be executed outside of
            the gradient reversal step, and before the generator
            and classifier loss.
        post_g: List of hooks that will be executed after
            the generator and classifier losses.
        gradient_reversal: Called before all D hooks, including
            ```pre_d```.
        gradient_reversal_weight: The ```weight``` given to the default
            gradient reversal hook. Unused if ```gradient_reversal```
            is provided.
        use_logits: If ```True```, then D receives the output of C
            instead of the output of G.
        f_hook: The hook used for computing features and logits.
            If ```None``` then it defaults to
            [```FeaturesForDomainLossHook```][pytorch_adapt.hooks.FeaturesForDomainLossHook]
        d_hook: The hook used for computing discriminator logits.
            If ```None``` then it defaults to
            [```DLogitsHook```][pytorch_adapt.hooks.DLogitsHook]
        c_hook: The hook used for computing the classifier's loss.
            If ```None``` then it defaults to
            [```CLossHook```][pytorch_adapt.hooks.CLossHook]
        domain_loss_hook: The hook used for computing the domain loss.
            If ```None``` then it defaults to
            [```DomainLossHook```][pytorch_adapt.hooks.DomainLossHook].
        d_hook_allowed: A regex string that specifies the allowed
            output names of the discriminator block.
    """
    super().__init__(**kwargs)
    # Replace any hook lists left as None with empty lists.
    [pre, pre_d, post_d, pre_g, post_g] = c_f.many_default(
        [pre, pre_d, post_d, pre_g, post_g], [[], [], [], [], []]
    )
    f_hook = c_f.default(
        f_hook, FeaturesForDomainLossHook, {"use_logits": use_logits}
    )
    # Reversal applies to the feature hook's outputs, so D's gradients
    # reach G negated — the core DANN mechanism.
    gradient_reversal = c_f.default(
        gradient_reversal,
        GradientReversalHook,
        {"weight": gradient_reversal_weight, "apply_to": f_hook.out_keys},
    )
    c_hook = c_f.default(c_hook, CLossHook, {})
    domain_loss_hook = c_f.default(
        domain_loss_hook, DomainLossHook, {"f_hook": f_hook, "d_hook": d_hook}
    )

    # Discriminator branch: reversal -> pre_d -> domain loss -> post_d.
    # AssertHook enforces that new outputs match d_hook_allowed;
    # NOTE(review): OnlyNewOutputsHook presumably exposes only this
    # branch's new outputs to later hooks — confirm against its docs.
    disc_hook = AssertHook(
        OnlyNewOutputsHook(
            ChainHook(
                gradient_reversal,
                *pre_d,
                domain_loss_hook,
                *post_d,
                overwrite=[1],
            )
        ),
        d_hook_allowed,
    )
    # Generator/classifier branch runs outside the gradient reversal.
    gen_hook = ChainHook(*pre_g, c_hook, *post_g)

    # Full iteration: features, both branches, optimizer step, then a
    # summary of the total loss.
    hook = ChainHook(*pre, f_hook, disc_hook, gen_hook)
    hook = OptimizerHook(hook, opts, weighter, reducer)
    s_hook = SummaryHook({"total_loss": hook})
    self.hook = ChainHook(hook, s_hook)