Skip to content

adda

ADDAHook

Bases: GANHook

Implementation of Adversarial Discriminative Domain Adaptation.

Source code in pytorch_adapt\hooks\adda.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class ADDAHook(GANHook):
    """
    Implementation of
    [Adversarial Discriminative Domain Adaptation](https://arxiv.org/abs/1702.05464).
    """

    def __init__(self, threshold: float = 0.6, pre_g=None, post_g=None, **kwargs):
        """
        Arguments:
            threshold: In each training iteration, the generator is only updated
                if the discriminator's accuracy is greater than `threshold`.
            pre_g: Optional list of hooks forwarded to ```GANHook``` as ```pre_g```.
                ```None``` is resolved through ```c_f.many_default```
                (presumably to an empty list).
            post_g: Optional list of hooks forwarded to ```GANHook``` as
                ```post_g```. Handled the same way as ```pre_g```.
            **kwargs: Passed through unchanged to ```GANHook.__init__```.
        """
        pre_g, post_g = c_f.many_default([pre_g, post_g], [[], []])

        # Hooks run before the discriminator step: detached "src" features
        # wrapped in a FrozenModelHook for "G", chained with the "target"
        # features taken from model "T".
        source_features = FrozenModelHook(
            FeaturesHook(detach=True, domains=["src"]), "G"
        )
        target_features = FeaturesWithGradAndDetachedHook(
            model_name="T", domains=["target"]
        )

        # One condition per hook slot, laid out as:
        #   [pre_g ...] + [generator condition, classifier condition] + [post_g ...]
        # Every pre_g/post_g slot gets a TrueHook; the generator slot is gated
        # by StrongDHook(threshold) and the classifier slot by FalseHook.
        gen_conditions = (
            [TrueHook() for _ in range(len(pre_g))]
            + [StrongDHook(threshold), FalseHook()]
            + [TrueHook() for _ in range(len(post_g))]
        )

        super().__init__(
            pre_d=[ChainHook(source_features, target_features)],
            pre_g=pre_g,
            post_g=post_g,
            gen_conditions=gen_conditions,
            gen_domains=["target"],
            c_hook=EmptyHook(),
            **kwargs
        )

__init__(threshold=0.6, pre_g=None, post_g=None, **kwargs)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| threshold | float | In each training iteration, the generator is only updated if the discriminator's accuracy is greater than threshold. | 0.6 |
Source code in pytorch_adapt\hooks\adda.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def __init__(self, threshold: float = 0.6, pre_g=None, post_g=None, **kwargs):
    """
    Arguments:
        threshold: In each training iteration, the generator is only updated
            if the discriminator's accuracy is greater than `threshold`.
        pre_g: Optional list of hooks forwarded to the parent constructor as
            ```pre_g```. ```None``` is resolved through ```c_f.many_default```
            (presumably to an empty list).
        post_g: Optional list of hooks forwarded to the parent constructor as
            ```post_g```. Handled the same way as ```pre_g```.
        **kwargs: Passed through unchanged to the parent constructor.
    """
    pre_g, post_g = c_f.many_default([pre_g, post_g], [[], []])

    # Hooks run before the discriminator step: detached "src" features
    # wrapped in a FrozenModelHook for "G", chained with the "target"
    # features taken from model "T".
    source_features = FrozenModelHook(
        FeaturesHook(detach=True, domains=["src"]), "G"
    )
    target_features = FeaturesWithGradAndDetachedHook(
        model_name="T", domains=["target"]
    )

    # One condition per hook slot, laid out as:
    #   [pre_g ...] + [generator condition, classifier condition] + [post_g ...]
    # Every pre_g/post_g slot gets a TrueHook; the generator slot is gated
    # by StrongDHook(threshold) and the classifier slot by FalseHook.
    gen_conditions = (
        [TrueHook() for _ in range(len(pre_g))]
        + [StrongDHook(threshold), FalseHook()]
        + [TrueHook() for _ in range(len(post_g))]
    )

    super().__init__(
        pre_d=[ChainHook(source_features, target_features)],
        pre_g=pre_g,
        post_g=post_g,
        gen_conditions=gen_conditions,
        gen_domains=["target"],
        c_hook=EmptyHook(),
        **kwargs
    )