Bases: `BaseWrapperHook`

Passes inputs into the model without doing any optimization. The model is expected to receive a `domain` argument and update its BatchNorm parameters itself.
Source code in `pytorch_adapt/hooks/adabn.py`
class AdaBNHook(BaseWrapperHook):
    """
    Passes inputs into model without doing any optimization.
    The model is expected to receive a ```domain``` argument
    and update its BatchNorm parameters itself.
    """

    def __init__(self, domains=None, **kwargs):
        super().__init__(**kwargs)
        # Default to adapting both the source and target domains.
        domains = c_f.default(domains, ["src", "target"])
        # One features->logits chain per domain, run side by side.
        per_domain = [
            FeaturesChainHook(
                DomainSpecificFeaturesHook(domains=[d], detach=True),
                DomainSpecificLogitsHook(domains=[d], detach=True),
            )
            for d in domains
        ]
        self.hook = ParallelHook(*per_domain)

    def call(self, inputs, losses):
        # Forward passes only — no gradients are needed, since the model
        # updates its BatchNorm statistics itself during the forward call.
        with torch.no_grad():
            return self.hook(inputs, losses)