Bases: BaseWrapperHook
Implementation of [Domain-Symmetric Networks for Adversarial Domain Adaptation](https://arxiv.org/abs/1904.04663).
Source code in `pytorch_adapt/hooks/symnets.py` (lines 79–100).
class SymNetsHook(BaseWrapperHook):
    """
    Implementation of
    [Domain-Symmetric Networks for Adversarial Domain Adaptation](https://arxiv.org/abs/1904.04663).

    The wrapped ``self.hook`` is a ``ChainHook`` built from, in this order:
    a ``FeaturesHook``, an ``OptimizerHook`` wrapping ``SymNetsCHook``,
    an ``OptimizerHook`` wrapping ``SymNetsGHook``, and a ``SummaryHook``
    that refers to the two optimizer hooks under the keys ``"c_loss"`` and
    ``"g_loss"``.

    Arguments:
        c_opts: passed as the second argument of the ``OptimizerHook`` that
            wraps ``SymNetsCHook`` (presumably the optimizers stepped with
            the C loss — confirm against ``OptimizerHook``).
        g_opts: passed as the second argument of the ``OptimizerHook`` that
            wraps ``SymNetsGHook`` (presumably the optimizers stepped with
            the G loss — confirm against ``OptimizerHook``).
        c_weighter: third argument of the C-side ``OptimizerHook``;
            defaults to ``None``.
        c_reducer: fourth argument of the C-side ``OptimizerHook``;
            defaults to ``None``.
        g_weighter: third argument of the G-side ``OptimizerHook``;
            defaults to ``None``.
        g_reducer: fourth argument of the G-side ``OptimizerHook``;
            defaults to ``None``.
        **kwargs: forwarded unchanged to ``BaseWrapperHook.__init__``.
    """

    def __init__(
        self,
        c_opts,
        g_opts,
        c_weighter=None,
        c_reducer=None,
        g_weighter=None,
        g_reducer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        features = FeaturesHook()
        # Each side gets its own OptimizerHook with its own optimizers,
        # weighter and reducer.
        c_step = OptimizerHook(SymNetsCHook(), c_opts, c_weighter, c_reducer)
        g_step = OptimizerHook(SymNetsGHook(), g_opts, g_weighter, g_reducer)
        # The summary holds the very same hook instances that run in the
        # chain (presumably so it reports the losses those hooks computed).
        losses = {"c_loss": c_step, "g_loss": g_step}
        stages = (features, c_step, g_step, SummaryHook(losses))
        self.hook = ChainHook(*stages)