Skip to content

adabn

AdaBNHook

Bases: BaseWrapperHook

Passes inputs into model without doing any optimization. The model is expected to receive a domain argument and update its BatchNorm parameters itself.

Source code in `pytorch_adapt/hooks/adabn.py` (lines 30–49):
class AdaBNHook(BaseWrapperHook):
    """
    Runs the model on each domain's inputs without performing any
    optimization step. The model is expected to accept a ```domain```
    argument and to update its BatchNorm parameters on its own.

    Args:
        domains: names of the domains to pass through the model.
            When ``None``, defaults to ``["src", "target"]``.
        **kwargs: forwarded unchanged to ``BaseWrapperHook.__init__``.
    """

    def __init__(self, domains=None, **kwargs):
        super().__init__(**kwargs)
        # One features -> logits chain per domain. Both sub-hooks are
        # created with detach=True, since this hook never optimizes anything.
        per_domain_chains = [
            FeaturesChainHook(
                DomainSpecificFeaturesHook(domains=[name], detach=True),
                DomainSpecificLogitsHook(domains=[name], detach=True),
            )
            for name in c_f.default(domains, ["src", "target"])
        ]
        self.hook = ParallelHook(*per_domain_chains)

    def call(self, inputs, losses):
        """
        Run every per-domain chain with autograd disabled and return
        whatever the wrapped ``ParallelHook`` returns.
        """
        with torch.no_grad():
            result = self.hook(inputs, losses)
        return result