class MCDHook(BaseWrapperHook):
    """
    Implementation of
    [Maximum Classifier Discrepancy for Unsupervised Domain Adaptation](https://arxiv.org/abs/1712.02560).

    The wrapped hook combines three sub-steps with ``ParallelHook`` and then
    records their losses with ``SummaryHook`` under the keys ``"x_loss"``,
    ``"y_loss"`` and ``"z_loss"``:

    - ``x``: ``MultipleCLossHook``, optimized by ``c_opts`` and ``g_opts``.
    - ``y``: ``MultipleCLossHook`` plus an ``MCDLossHook`` with
      ``minimize=False`` (both with ``detach_features=True``), optimized by
      ``c_opts`` only.
    - ``z``: ``MCDLossHook``, optimized by ``g_opts`` only and run
      ``repeat`` times through ``RepeatHook``.
    """

    def __init__(
        self,
        g_opts,
        c_opts,
        discrepancy_loss_fn=None,
        x_weighter=None,
        x_reducer=None,
        y_weighter=None,
        y_reducer=None,
        z_weighter=None,
        z_reducer=None,
        pre_x=None,
        post_x=None,
        pre_y=None,
        post_y=None,
        pre_z=None,
        post_z=None,
        repeat=4,
        **kwargs,
    ):
        """
        Arguments:
            g_opts: Optimizers presumably for the feature generator (G in
                the paper). They are used by the ``x`` and ``z`` sub-steps.
            c_opts: Optimizers presumably for the classifiers (C in the
                paper). They are used by the ``x`` and ``y`` sub-steps.
            discrepancy_loss_fn: Forwarded as ``loss_fn`` to both
                ``MCDLossHook`` instances. ``None`` is passed through
                unchanged; how it is handled is up to ``MCDLossHook``.
            x_weighter, y_weighter, z_weighter: Weighter given to the
                ``OptimizerHook`` of the matching sub-step.
            x_reducer, y_reducer, z_reducer: Reducer given to the
                ``OptimizerHook`` of the matching sub-step.
            pre_x, pre_y, pre_z: Hooks chained before the loss hooks of the
                matching sub-step. The fallback supplied to
                ``c_f.many_default`` is an empty list.
            post_x, post_y, post_z: Hooks chained after the loss hooks of the
                matching sub-step. The fallback supplied to
                ``c_f.many_default`` is an empty list.
            repeat: Repeat count for the ``z`` sub-step. It is passed to
                ``RepeatHook`` with ``keep_only_last=True``.
            **kwargs: Forwarded to ``BaseWrapperHook.__init__``.
        """
        super().__init__(**kwargs)

        hook_lists = [pre_x, post_x, pre_y, post_y, pre_z, post_z]
        # One distinct empty list per slot, so no two slots share a list.
        fallbacks = [[] for _ in hook_lists]
        (
            pre_x,
            post_x,
            pre_y,
            post_y,
            pre_z,
            post_z,
        ) = c_f.many_default(hook_lists, fallbacks)

        # Sub-step x: MultipleCLossHook, updated by classifiers + generator.
        step_x = ChainHook(*pre_x, MultipleCLossHook(), *post_x)

        # Sub-step y: MultipleCLossHook plus a discrepancy term that is
        # maximized (minimize=False); both hooks get detach_features=True.
        step_y = ChainHook(
            *pre_y,
            MultipleCLossHook(detach_features=True),
            MCDLossHook(
                detach_features=True,
                minimize=False,
                loss_fn=discrepancy_loss_fn,
            ),
            *post_y,
        )

        # Sub-step z: discrepancy term with MCDLossHook's default settings.
        step_z = ChainHook(
            *pre_z,
            MCDLossHook(loss_fn=discrepancy_loss_fn),
            *post_z,
        )

        opt_x = OptimizerHook(
            step_x, list(c_opts) + list(g_opts), x_weighter, x_reducer
        )
        opt_y = OptimizerHook(step_y, c_opts, y_weighter, y_reducer)
        opt_z = OptimizerHook(step_z, g_opts, z_weighter, z_reducer)

        # The summary is built from the un-repeated z hook. Only the instance
        # handed to ParallelHook below is wrapped in RepeatHook.
        summary = SummaryHook(
            {"x_loss": opt_x, "y_loss": opt_y, "z_loss": opt_z}
        )
        repeated_z = RepeatHook(opt_z, repeat, keep_only_last=True)
        self.hook = ChainHook(ParallelHook(opt_x, opt_y, repeated_z), summary)