class GVBHook(DANNHook):
    """
    Implementation of
    [Gradually Vanishing Bridge for Adversarial Domain Adaptation](https://arxiv.org/abs/2003.13183)

    Builds on ``DANNHook`` by wiring in bridge-aware feature/discriminator
    hooks and a softmax gradient-reversal hook applied to the logits.

    Args:
        gradient_reversal_weight: passed as ``weight`` to the
            ``SoftmaxGradientReversalHook``.
        pre: optional iterable of hooks; ``FeaturesLogitsAndGBridge()`` is
            appended to a copy of it and the result is forwarded to
            ``DANNHook`` as ``pre``.
        pre_d: optional iterable of hooks; ``DBridgeLossHook()`` is appended
            to a copy of it and forwarded as ``pre_d``.
        pre_g: optional iterable of hooks; ``GBridgeLossHook()`` is appended
            to a copy of it and forwarded as ``pre_g``.
        **kwargs: forwarded unchanged to ``DANNHook.__init__``.
    """

    def __init__(
        self, gradient_reversal_weight=1, pre=None, pre_d=None, pre_g=None, **kwargs
    ):
        # f_hook and d_hook are used inside DomainLossHook
        f_hook = FeaturesForDomainLossHook(use_logits=True)
        d_hook = DBridgeAndLogitsHook()
        # Gradient reversal targets only the f_hook output keys ending in "_logits".
        apply_to = c_f.filter(f_hook.out_keys, "_logits$")
        gradient_reversal = SoftmaxGradientReversalHook(
            weight=gradient_reversal_weight, apply_to=apply_to
        )
        [pre, pre_d, pre_g] = c_f.many_default([pre, pre_d, pre_g], [[], [], []])
        # Build new lists instead of using ``+=``: in-place extension would
        # mutate a list the caller passed in (and fail outright for tuples).
        pre = [*pre, FeaturesLogitsAndGBridge()]
        pre_d = [*pre_d, DBridgeLossHook()]
        pre_g = [*pre_g, GBridgeLossHook()]
        super().__init__(
            pre=pre,
            pre_d=pre_d,
            pre_g=pre_g,
            gradient_reversal=gradient_reversal,
            f_hook=f_hook,
            d_hook=d_hook,
            # NOTE(review): presumably restricts which discriminator output
            # keys DANNHook accepts to "*_dlogits" / "*_dbridge" — confirm in DANNHook.
            d_hook_allowed="_dlogits$|_dbridge$",
            **kwargs,
        )