
mcd

MCD

Bases: BaseGCAdapter

Wraps MCDHook.

|Container|Required keys|
|---|---|
|models|`["G", "C"]`|
|optimizers|`["G", "C"]`|

The C model must output a list of logits, where each list element corresponds to a separate classifier. Typically there are 2 classifiers, so C should output [logits1, logits2].
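To illustrate the required output format, here is a minimal sketch of a C model in plain PyTorch. `MultiHeadClassifier`, the feature extractor `G`, and all layer sizes are hypothetical choices for illustration. Any module whose `forward` returns a list of logits tensors, one per classifier, should have the same output shape.

```python
import torch
import torch.nn as nn


class MultiHeadClassifier(nn.Module):
    """Hypothetical C model: independent linear heads over shared features."""

    def __init__(self, in_features: int, num_classes: int, num_heads: int = 2):
        super().__init__()
        # One separate classifier per head; two heads is the usual setup.
        self.heads = nn.ModuleList(
            [nn.Linear(in_features, num_classes) for _ in range(num_heads)]
        )

    def forward(self, features: torch.Tensor) -> list:
        # Return one logits tensor per classifier: [logits1, logits2, ...]
        return [head(features) for head in self.heads]


G = nn.Sequential(nn.Linear(32, 16), nn.ReLU())  # hypothetical feature extractor
C = MultiHeadClassifier(in_features=16, num_classes=10)

x = torch.randn(4, 32)
logits1, logits2 = C(G(x))
print(logits1.shape, logits2.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
```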

Source code in pytorch_adapt\adapters\mcd.py
```python
class MCD(BaseGCAdapter):
    """
    Wraps [MCDHook][pytorch_adapt.hooks.MCDHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["G", "C"]```|

    The C model must output a list of logits, where each list element
    corresponds with a separate classifier. Usually the number of
    classifiers is 2, so C should output ```[logits1, logits2]```.
    """

    def __init__(self, *args, inference_fn=None, **kwargs):
        inference_fn = c_f.default(inference_fn, mcd_fn)
        super().__init__(*args, inference_fn=inference_fn, **kwargs)

    def init_hook(self, hook_kwargs):
        self.hook = self.hook_cls(
            g_opts=with_opt(["G"]), c_opts=with_opt(["C"]), **hook_kwargs
        )

    @property
    def hook_cls(self):
        return MCDHook
```
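At inference time the list returned by C has to be reduced to a single prediction. By default `MCD` delegates this to `mcd_fn`, whose body is not shown on this page. The sketch below is a hypothetical stand-in, not `mcd_fn` itself: it assumes the heads are merged by summing their logits. `combined_predict`, `head1`, `head2`, and the layer sizes are illustrative names only.

```python
from typing import Callable, Dict, List

import torch
import torch.nn as nn


def combined_predict(
    G: Callable[[torch.Tensor], torch.Tensor],
    C: Callable[[torch.Tensor], List[torch.Tensor]],
    x: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Hypothetical inference helper; NOT pytorch_adapt's mcd_fn.

    Assumes C returns a list of logits tensors and merges them by summation.
    """
    with torch.no_grad():
        features = G(x)
        logits_list = C(features)
        # Element-wise sum over heads (an assumed combination rule).
        logits = torch.stack(logits_list, dim=0).sum(dim=0)
    return {"features": features, "logits": logits, "preds": logits.argmax(dim=1)}


# Usage with a toy feature extractor and two hypothetical linear heads.
G = nn.Linear(8, 6)
head1, head2 = nn.Linear(6, 3), nn.Linear(6, 3)


def C(features: torch.Tensor) -> List[torch.Tensor]:
    return [head1(features), head2(features)]


out = combined_predict(G, C, torch.randn(5, 8))
print(out["logits"].shape, out["preds"].shape)  # torch.Size([5, 3]) torch.Size([5])
```

Averaging the heads instead of summing them would give the same `argmax`, because dividing by the number of heads does not change which class scores highest.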