Skip to content

gan

CDAN

Bases: GAN

Wraps CDANHook.

| Container | Required keys |
|---|---|
| models | `["G", "C", "D"]` |
| optimizers | `["G", "C", "D"]` |
| misc | `["feature_combiner"]` |
Source code in `pytorch_adapt/adapters/gan.py` (lines 42–60):
class CDAN(GAN):
    """
    Wraps [CDANHook][pytorch_adapt.hooks.CDANHook].

    Builds on the GAN adapter's container requirements and additionally
    requires a `"feature_combiner"` entry in the `misc` container.

    |Container|Required keys|
    |---|---|
    |models|```["G", "C", "D"]```|
    |optimizers|```["G", "C", "D"]```|
    |misc|```["feature_combiner"]```|
    """

    @property
    def hook_cls(self):
        """The hook class this adapter builds (`CDANHook`)."""
        chosen_hook = CDANHook
        return chosen_hook

    def get_key_enforcer(self) -> KeyEnforcer:
        """
        Return the parent's key enforcer with a `misc` requirement added.

        The `misc` entry is assigned, not appended to, so any `misc`
        requirement set by the parent is replaced by `["feature_combiner"]`.
        """
        enforcer = super().get_key_enforcer()
        required_misc = ["feature_combiner"]
        enforcer.requirements["misc"] = required_misc
        return enforcer

DomainConfusion

Bases: GAN

Wraps DomainConfusionHook.

| Container | Required keys |
|---|---|
| models | `["G", "C", "D"]` |
| optimizers | `["G", "C", "D"]` |
Source code in `pytorch_adapt/adapters/gan.py` (lines 69–81):
class DomainConfusion(GAN):
    """
    Wraps [DomainConfusionHook][pytorch_adapt.hooks.DomainConfusionHook].

    Behaves like the GAN adapter; the only difference is which hook
    class gets instantiated.

    |Container|Required keys|
    |---|---|
    |models|```["G", "C", "D"]```|
    |optimizers|```["G", "C", "D"]```|
    """

    @property
    def hook_cls(self):
        """The hook class this adapter builds (`DomainConfusionHook`)."""
        chosen_hook = DomainConfusionHook
        return chosen_hook

GAN

Bases: BaseGCDAdapter

Wraps GANHook.

| Container | Required keys |
|---|---|
| models | `["G", "C", "D"]` |
| optimizers | `["G", "C", "D"]` |
Source code in `pytorch_adapt/adapters/gan.py` (lines 16–33):
class GAN(BaseGCDAdapter):
    """
    Wraps [GANHook][pytorch_adapt.hooks.GANHook].

    The optimizers for `"G"` and `"C"` are passed to the hook as its
    generator-side optimizers, and the optimizer for `"D"` as its
    discriminator-side optimizer. Subclasses select a different hook
    by overriding `hook_cls`.

    |Container|Required keys|
    |---|---|
    |models|```["G", "C", "D"]```|
    |optimizers|```["G", "C", "D"]```|
    """

    def init_hook(self, hook_kwargs):
        """
        Instantiate `self.hook_cls` and store it as `self.hook`.

        Args:
            hook_kwargs: extra keyword arguments forwarded to the hook
                constructor alongside `d_opts` and `g_opts`.
        """
        generator_side_opts = with_opt(["G", "C"])
        discriminator_side_opts = with_opt(["D"])
        hook_factory = self.hook_cls
        self.hook = hook_factory(
            d_opts=discriminator_side_opts,
            g_opts=generator_side_opts,
            **hook_kwargs,
        )

    @property
    def hook_cls(self):
        """The hook class this adapter builds (`GANHook`)."""
        chosen_hook = GANHook
        return chosen_hook

VADA

Bases: GAN

Wraps VADAHook.

| Container | Required keys |
|---|---|
| models | `["G", "C", "D"]` |
| optimizers | `["G", "C", "D"]` |
| misc | `["combined_model"]` |

The `"combined_model"` key does not need to be passed in. It is always set to `torch.nn.Sequential(G, C)` during initialization, so any value supplied under this key is replaced.

Source code in `pytorch_adapt/adapters/gan.py` (lines 84–112):
class VADA(GAN):
    """
    Wraps [VADAHook][pytorch_adapt.hooks.VADAHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C", "D"]```|
    |optimizers|```["G", "C", "D"]```|
    |misc|```["combined_model"]```|

    The ```"combined_model"``` key does not need to be passed in.
    It is always set to ```torch.nn.Sequential(G, C)``` during
    initialization, so any value supplied under that key is replaced.
    """

    @property
    def hook_cls(self):
        """The hook class this adapter builds (`VADAHook`)."""
        chosen_hook = VADAHook
        return chosen_hook

    def init_containers_and_check_keys(self, containers):
        """
        Add `misc["combined_model"]`, then defer to the parent.

        Args:
            containers: must provide a `"models"` container holding
                `"G"` and `"C"`, and a `"misc"` container.
        """
        model_container = containers["models"]
        misc_container = containers["misc"]
        # Applies G and then C. It is inserted before the parent call so
        # the "combined_model" requirement from get_key_enforcer is
        # satisfied without the user supplying it (presumably checked
        # by the parent — confirm in BaseGCDAdapter).
        combined = torch.nn.Sequential(model_container["G"], model_container["C"])
        misc_container["combined_model"] = combined
        super().init_containers_and_check_keys(containers)

    def get_key_enforcer(self) -> KeyEnforcer:
        """
        Return the parent's key enforcer with a `misc` requirement added.

        The `misc` entry is assigned, not appended to, so any `misc`
        requirement set by the parent is replaced by `["combined_model"]`.
        """
        enforcer = super().get_key_enforcer()
        required_misc = ["combined_model"]
        enforcer.requirements["misc"] = required_misc
        return enforcer