Skip to content

gvb

GVBHook

Bases: DANNHook

Implementation of [Gradually Vanishing Bridge for Adversarial Domain Adaptation](https://arxiv.org/abs/2003.13183).

Source code in `pytorch_adapt/hooks/gvb.py` (lines 127–157)
class GVBHook(DANNHook):
    """
    Implementation of
    [Gradually Vanishing Bridge for Adversarial Domain Adaptation](https://arxiv.org/abs/2003.13183)
    """

    def __init__(
        self, gradient_reversal_weight=1, pre=None, pre_d=None, pre_g=None, **kwargs
    ):
        """
        Arguments:
            gradient_reversal_weight: passed as ``weight`` to the
                ``SoftmaxGradientReversalHook``, which is applied to the
                features hook's output keys ending in ``_logits``.
            pre: optional list of hooks; defaults to ``[]``. A
                ``FeaturesLogitsAndGBridge`` hook is appended to a copy of it.
            pre_d: optional list of hooks; defaults to ``[]``. A
                ``DBridgeLossHook`` is appended to a copy of it.
            pre_g: optional list of hooks; defaults to ``[]``. A
                ``GBridgeLossHook`` is appended to a copy of it.
            **kwargs: forwarded unchanged to ``DANNHook.__init__``.

        The ``pre``/``pre_d``/``pre_g`` objects passed in by the caller are
        never modified.
        """
        # f_hook and d_hook are used inside DomainLossHook
        f_hook = FeaturesForDomainLossHook(use_logits=True)
        d_hook = DBridgeAndLogitsHook()
        # Gradient reversal is applied only to the logits produced by f_hook.
        apply_to = c_f.filter(f_hook.out_keys, "_logits$")
        gradient_reversal = SoftmaxGradientReversalHook(
            weight=gradient_reversal_weight, apply_to=apply_to
        )
        pre, pre_d, pre_g = c_f.many_default([pre, pre_d, pre_g], [[], [], []])
        # Build new lists instead of using ``+=``: many_default hands back the
        # caller's own list when one is supplied, and an in-place extend would
        # mutate it (e.g. duplicating the bridge hooks if the same list is
        # reused to construct another GVBHook).
        pre = [*pre, FeaturesLogitsAndGBridge()]
        pre_d = [*pre_d, DBridgeLossHook()]
        pre_g = [*pre_g, GBridgeLossHook()]

        super().__init__(
            pre=pre,
            pre_d=pre_d,
            pre_g=pre_g,
            gradient_reversal=gradient_reversal,
            f_hook=f_hook,
            d_hook=d_hook,
            # Restrict the discriminator-output keys to the domain logits and
            # the discriminator bridge.
            d_hook_allowed="_dlogits$|_dbridge$",
            **kwargs,
        )