Bases: BaseWrapperHook
Implementation of [Unsupervised Domain Adaptation with Residual Transfer Networks](https://arxiv.org/abs/1602.04433).
Source code in `pytorch_adapt/hooks/rtn.py` (lines 57–78):
class RTNHook(BaseWrapperHook):
    """
    Implementation of
    [Unsupervised Domain Adaptation with Residual Transfer Networks](https://arxiv.org/abs/1602.04433).

    The wrapped pipeline has three parts:

    1. A chain of the user ``pre`` hooks, then ``RTNAlignerHook`` and
       ``RTNLogitsHook``, then the user ``post`` hooks.
    2. That chain, wrapped in an ``OptimizerHook``.
    3. A ``SummaryHook`` that reports the optimizer hook under the
       ``"total_loss"`` key.

    Arguments:
        opts: forwarded unchanged to ``OptimizerHook`` (presumably the
            optimizers it steps; confirm against ``OptimizerHook``).
        weighter: forwarded unchanged to ``OptimizerHook``.
        reducer: forwarded unchanged to ``OptimizerHook``.
        pre: hooks placed before the RTN hooks in the chain.
        post: hooks placed after the RTN hooks in the chain.
        aligner_loss_fn: forwarded unchanged to ``RTNAlignerHook``.
        **kwargs: forwarded to ``BaseWrapperHook.__init__``.
    """

    def __init__(
        self,
        opts,
        weighter=None,
        reducer=None,
        pre=None,
        post=None,
        aligner_loss_fn=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Presumably replaces a None pre/post with its default ([] here),
        # so the star-unpacking below always receives an iterable.
        pre, post = c_f.many_default([pre, post], [[], []])

        # The RTN-specific hooks sit between the user-supplied pre and post hooks.
        rtn_hooks = [RTNAlignerHook(aligner_loss_fn), RTNLogitsHook()]
        loss_chain = ChainHook(*pre, *rtn_hooks, *post)

        optimizer_hook = OptimizerHook(loss_chain, opts, weighter, reducer)
        summary_hook = SummaryHook({"total_loss": optimizer_hook})
        self.hook = ChainHook(optimizer_hook, summary_hook)