Skip to content

aligner

Aligner

Bases: BaseGCAdapter

Wraps AlignerPlusCHook.

| Container | Required keys |
|---|---|
| models | `["G", "C"]` |
| optimizers | `["G", "C"]` |
Source code in `pytorch_adapt/adapters/aligner.py` (lines 9–25):
class Aligner(BaseGCAdapter):
    """
    Wraps [AlignerPlusCHook][pytorch_adapt.hooks.AlignerPlusCHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["G", "C"]```|
    """

    def init_hook(self, hook_kwargs):
        """
        Instantiate ``self.hook`` from this adapter's optimizer names.

        Arguments:
            hook_kwargs: extra keyword arguments forwarded unchanged to
                ``self.hook_cls`` after the list of optimizer names.
        """
        # One entry per key of the optimizers container, in container order.
        optimizer_names = list(self.optimizers.keys())
        # NOTE(review): with_opt presumably maps container keys to the
        # optimizer names the hook expects -- confirm against its definition.
        wrapped_names = with_opt(optimizer_names)
        self.hook = self.hook_cls(wrapped_names, **hook_kwargs)

    @property
    def hook_cls(self):
        """
        The hook class that ``init_hook`` instantiates.

        Subclasses override this property to swap in a different hook.
        """
        return AlignerPlusCHook

RTN

Bases: Aligner

Wraps RTNHook.

| Container | Required keys |
|---|---|
| models | `["G", "C", "residual_model"]` |
| optimizers | `["G", "C", "residual_model"]` |
| misc | `["feature_combiner"]` |
Source code in `pytorch_adapt/adapters/aligner.py` (lines 28–56):
class RTN(Aligner):
    """
    Wraps [RTNHook][pytorch_adapt.hooks.RTNHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C", "residual_model"]```|
    |optimizers|```["G", "C", "residual_model"]```|
    |misc|```["feature_combiner"]```|
    """

    def __init__(self, *args, inference_fn=None, **kwargs):
        """
        Arguments:
            inference_fn: Default is [rtn_fn][pytorch_adapt.inference.rtn_fn]
        """
        # Substitute the RTN-specific inference function when none is given;
        # every other argument is passed through to Aligner untouched.
        chosen_fn = c_f.default(inference_fn, rtn_fn)
        super().__init__(*args, inference_fn=chosen_fn, **kwargs)

    @property
    def hook_cls(self):
        """The hook class that ``init_hook`` instantiates: ``RTNHook``."""
        return RTNHook

    def get_key_enforcer(self) -> KeyEnforcer:
        """
        Extend the parent's key requirements with RTN's extra entries.

        Appends ``"residual_model"`` to both the ``models`` and
        ``optimizers`` requirements, then sets the ``misc`` requirement to
        ``["feature_combiner"]``.

        Returns:
            The enforcer returned by the parent class, modified in place.
        """
        enforcer = super().get_key_enforcer()
        requirements = enforcer.requirements
        for container in ("models", "optimizers"):
            requirements[container].append("residual_model")
        # Plain assignment: any misc requirement set by a parent is replaced,
        # not extended.
        requirements["misc"] = ["feature_combiner"]
        return enforcer

__init__(*args, inference_fn=None, **kwargs)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `inference_fn` | | Default is `rtn_fn` | `None` |
Source code in `pytorch_adapt/adapters/aligner.py` (lines 39–45):
def __init__(self, *args, inference_fn=None, **kwargs):
    """
    Arguments:
        inference_fn: Default is [rtn_fn][pytorch_adapt.inference.rtn_fn]
    """
    # Fall back to rtn_fn when the caller supplies no inference function;
    # positional and remaining keyword arguments are forwarded unchanged.
    resolved_inference_fn = c_f.default(inference_fn, rtn_fn)
    super().__init__(*args, inference_fn=resolved_inference_fn, **kwargs)