Skip to content

mcd

MCDHook

Bases: BaseWrapperHook

Implementation of Maximum Classifier Discrepancy for Unsupervised Domain Adaptation.

Source code in `pytorch_adapt/hooks/mcd.py` (lines 46–93):
class MCDHook(BaseWrapperHook):
    """
    Implementation of
    [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation](https://arxiv.org/abs/1712.02560).

    The wrapped hook runs three optimizer steps inside a ``ParallelHook``:

    - **x**: ``MultipleCLossHook()``, stepping both ``c_opts`` and ``g_opts``.
    - **y**: ``MultipleCLossHook(detach_features=True)`` followed by
      ``MCDLossHook(detach_features=True, minimize=False)``, stepping only
      ``c_opts``.
    - **z**: ``MCDLossHook`` with its default settings, stepping only
      ``g_opts``. It is wrapped in a ``RepeatHook`` that runs it ``repeat``
      times with ``keep_only_last=True``.

    A ``SummaryHook`` keyed ``"x_loss"``, ``"y_loss"`` and ``"z_loss"`` is
    chained after the parallel steps.

    Args:
        g_opts: Optimizers stepped by the x and z steps (presumably those of
            the feature generator; confirm against callers).
        c_opts: Optimizers stepped by the x and y steps (presumably those of
            the classifiers; confirm against callers).
        discrepancy_loss_fn: Passed as ``loss_fn`` to both ``MCDLossHook``
            instances.
        x_weighter, x_reducer: Passed to the ``OptimizerHook`` of step x.
        y_weighter, y_reducer: Passed to the ``OptimizerHook`` of step y.
        z_weighter, z_reducer: Passed to the ``OptimizerHook`` of step z.
        pre_x, post_x, pre_y, post_y, pre_z, post_z: Extra hooks chained
            before/after the core hooks of the matching step. ``None`` is
            resolved through ``c_f.many_default`` with an empty-list default.
        repeat: How many times ``RepeatHook`` runs step z.
        **kwargs: Forwarded to ``BaseWrapperHook.__init__``.
    """

    def __init__(
        self,
        g_opts,
        c_opts,
        discrepancy_loss_fn=None,
        x_weighter=None,
        x_reducer=None,
        y_weighter=None,
        y_reducer=None,
        z_weighter=None,
        z_reducer=None,
        pre_x=None,
        post_x=None,
        pre_y=None,
        post_y=None,
        pre_z=None,
        post_z=None,
        repeat=4,
        **kwargs,
    ):
        super().__init__(**kwargs)
        user_hooks = [pre_x, post_x, pre_y, post_y, pre_z, post_z]
        pre_x, post_x, pre_y, post_y, pre_z, post_z = c_f.many_default(
            user_hooks, [[] for _ in user_hooks]
        )

        # Step x: classifier and generator optimizers are both stepped.
        x_step = OptimizerHook(
            ChainHook(*pre_x, MultipleCLossHook(), *post_x),
            list(c_opts) + list(g_opts),
            x_weighter,
            x_reducer,
        )

        # Step y: features are detached and only the classifier optimizers
        # are stepped; the discrepancy hook is built with minimize=False.
        y_step = OptimizerHook(
            ChainHook(
                *pre_y,
                MultipleCLossHook(detach_features=True),
                MCDLossHook(
                    detach_features=True,
                    minimize=False,
                    loss_fn=discrepancy_loss_fn,
                ),
                *post_y,
            ),
            c_opts,
            y_weighter,
            y_reducer,
        )

        # Step z: only the generator optimizers are stepped.
        z_step = OptimizerHook(
            ChainHook(*pre_z, MCDLossHook(loss_fn=discrepancy_loss_fn), *post_z),
            g_opts,
            z_weighter,
            z_reducer,
        )

        # The summary is built from the un-wrapped z optimizer hook; the
        # RepeatHook wrapper is applied only afterwards.
        summary = SummaryHook(
            {"x_loss": x_step, "y_loss": y_step, "z_loss": z_step}
        )
        repeated_z = RepeatHook(z_step, repeat, keep_only_last=True)

        self.hook = ChainHook(ParallelHook(x_step, y_step, repeated_z), summary)