Skip to content

vada

VADAHook

Bases: GANHook

Implementation of VADA from the paper "A DIRT-T Approach to Unsupervised Domain Adaptation" (https://arxiv.org/abs/1802.08735).

Source code in pytorch_adapt\hooks\vada.py
51
52
53
54
55
56
57
58
59
60
class VADAHook(GANHook):
    """
    Implementation of VADA from
    [A DIRT-T Approach to Unsupervised Domain Adaptation](https://arxiv.org/abs/1802.08735).

    This is a ```GANHook``` whose ```post_g``` hooks are extended with a
    ```VATPlusEntropyHook```.
    """

    def __init__(self, vat_loss_fn=None, entropy_loss_fn=None, post_g=None, **kwargs):
        """
        Arguments:
            vat_loss_fn: Forwarded as the first argument of
                ```VATPlusEntropyHook```.
            entropy_loss_fn: Forwarded as the second argument of
                ```VATPlusEntropyHook```.
            post_g: Optional iterable of hooks. A ```VATPlusEntropyHook```
                is appended after them, and the result is passed to
                ```GANHook``` as ```post_g```. The object passed in is
                left unmodified.
            **kwargs: Forwarded unchanged to ```GANHook.__init__```.
        """
        # Build a new list rather than using ``+=``: in-place extension would
        # mutate the caller's list, so reusing one ``post_g`` list for several
        # hooks would accumulate a VATPlusEntropyHook per construction.
        post_g = [
            *c_f.default(post_g, []),
            VATPlusEntropyHook(vat_loss_fn, entropy_loss_fn),
        ]
        super().__init__(post_g=post_g, **kwargs)

VATHook

Bases: BaseWrapperHook

Applies the VATLoss to the source and target images, producing one loss per domain.

Source code in pytorch_adapt\hooks\vada.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class VATHook(BaseWrapperHook):
    """
    Applies the [```VATLoss```][pytorch_adapt.layers.VATLoss]
    to the source images and to the target images.
    """

    def __init__(self, loss_fn=None, **kwargs):
        """
        Arguments:
            loss_fn: The loss function to apply. If ```None```, the value is
                obtained from ```c_f.default(loss_fn, VATLoss, {})```.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(**kwargs)
        self.loss_fn = c_f.default(loss_fn, VATLoss, {})
        # Inner hook whose first return value supplies the outputs used below.
        self.hook = FeaturesAndLogitsHook()

    def call(self, inputs, losses):
        """
        Runs the inner hook, then computes one VAT loss per domain.

        Returns:
            A tuple ```(outputs, vat_losses)```, where ```outputs``` is the
            first element returned by the inner hook and ```vat_losses``` is a
            dict with the keys ```"src_vat_loss"``` and ```"target_vat_loss"```.
        """
        hook_result = self.hook(inputs, losses)
        outputs = hook_result[0]

        image_and_model_keys = ["src_imgs", "target_imgs", "combined_model"]
        src_imgs, target_imgs, combined_model = c_f.extract(
            inputs, image_and_model_keys
        )

        # Select the inner hook's logits keys for the src and target domains,
        # then look them up in the fresh outputs and the existing inputs.
        logits_keys = c_f.filter(
            self.hook.out_keys, "_logits$", ["^src", "^target"]
        )
        src_logits, target_logits = c_f.extract([outputs, inputs], logits_keys)

        vat_losses = {
            "src_vat_loss": self.loss_fn(src_imgs, src_logits, combined_model),
            "target_vat_loss": self.loss_fn(
                target_imgs, target_logits, combined_model
            ),
        }
        return outputs, vat_losses

    def _loss_keys(self):
        """Returns the names of the losses produced by ```call```."""
        return ["src_vat_loss", "target_vat_loss"]