Skip to content

symnets

SymNetsHook

Bases: BaseWrapperHook

Implementation of Domain-Symmetric Networks for Adversarial Domain Adaptation.

Source code in `pytorch_adapt/hooks/symnets.py` (lines 79–100):
class SymNetsHook(BaseWrapperHook):
    """
    Implementation of
    [Domain-Symmetric Networks for Adversarial Domain Adaptation](https://arxiv.org/abs/1904.04663).

    The wrapped hook (``self.hook``) is a ``ChainHook`` built from four
    stages, passed in this order:

    1. ``FeaturesHook``
    2. an ``OptimizerHook`` wrapping ``SymNetsCHook`` (the "C" step)
    3. an ``OptimizerHook`` wrapping ``SymNetsGHook`` (the "G" step)
    4. a ``SummaryHook`` that receives the C and G optimizer hooks under the
       keys ``"c_loss"`` and ``"g_loss"``.

    NOTE(review): presumably "C"/"G" correspond to the classifier and
    generator (feature-extractor) updates of the SymNets paper — confirm
    against ``SymNetsCHook`` / ``SymNetsGHook``.

    Args:
        c_opts: optimizers passed to the ``OptimizerHook`` wrapping
            ``SymNetsCHook``.
        g_opts: optimizers passed to the ``OptimizerHook`` wrapping
            ``SymNetsGHook``.
        c_weighter: weighter for the C-step ``OptimizerHook``; defaults to ``None``.
        c_reducer: reducer for the C-step ``OptimizerHook``; defaults to ``None``.
        g_weighter: weighter for the G-step ``OptimizerHook``; defaults to ``None``.
        g_reducer: reducer for the G-step ``OptimizerHook``; defaults to ``None``.
        **kwargs: forwarded unchanged to ``BaseWrapperHook.__init__``.
    """

    def __init__(
        self,
        c_opts,
        g_opts,
        c_weighter=None,
        c_reducer=None,
        g_weighter=None,
        g_reducer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Sub-hooks are constructed in the same order they appear in the chain;
        # keep it that way in case any constructor has side effects.
        features = FeaturesHook()
        c_step = OptimizerHook(SymNetsCHook(), c_opts, c_weighter, c_reducer)
        g_step = OptimizerHook(SymNetsGHook(), g_opts, g_weighter, g_reducer)
        summary = SummaryHook(dict(c_loss=c_step, g_loss=g_step))
        stages = (features, c_step, g_step, summary)
        self.hook = ChainHook(*stages)