Skip to content

rtn

RTNHook

Bases: BaseWrapperHook

Implementation of Unsupervised Domain Adaptation with Residual Transfer Networks.

Source code in pytorch_adapt\hooks\rtn.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class RTNHook(BaseWrapperHook):
    """
    Implementation of
    [Unsupervised Domain Adaptation with Residual Transfer Networks](https://arxiv.org/abs/1602.04433).

    Builds a chain of `RTNAlignerHook` followed by `RTNLogitsHook`
    (optionally surrounded by the `pre` and `post` hooks), wraps that
    chain in an `OptimizerHook`, and finally chains a `SummaryHook`
    after it. The result is stored in `self.hook`.
    """

    def __init__(
        self,
        opts,
        weighter=None,
        reducer=None,
        pre=None,
        post=None,
        aligner_loss_fn=None,
        **kwargs,
    ):
        """
        Arguments:
            opts: Forwarded to `OptimizerHook` as its second positional
                argument.
            weighter: Forwarded to `OptimizerHook` as its third
                positional argument.
            reducer: Forwarded to `OptimizerHook` as its fourth
                positional argument.
            pre: Iterable of hooks placed before the RTN hooks in the
                inner chain.
            post: Iterable of hooks placed after the RTN hooks in the
                inner chain.
            aligner_loss_fn: Forwarded to `RTNAlignerHook` as its only
                positional argument.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        super().__init__(**kwargs)

        # Presumably swaps in an empty list wherever pre/post is None --
        # confirm against c_f.many_default.
        pre, post = c_f.many_default([pre, post], [[], []])

        # Inner chain: user "pre" hooks, the two RTN hooks, user "post" hooks.
        inner_chain = ChainHook(
            *pre,
            RTNAlignerHook(aligner_loss_fn),
            RTNLogitsHook(),
            *post,
        )
        optimizer_hook = OptimizerHook(inner_chain, opts, weighter, reducer)

        # The summary hook is registered under the "total_loss" key and
        # chained after the optimizer hook.
        summary_hook = SummaryHook({"total_loss": optimizer_hook})
        self.hook = ChainHook(optimizer_hook, summary_hook)