Skip to content

classifier

Classifier

Bases: BaseGCAdapter

Wraps ClassifierHook.

| Container | Required keys |
|---|---|
| models | `["G", "C"]` |
| optimizers | `["G", "C"]` |
Source code in `pytorch_adapt/adapters/classifier.py` (lines 12–28):
class Classifier(BaseGCAdapter):
    """
    Wraps [ClassifierHook][pytorch_adapt.hooks.ClassifierHook].

    The hook class is taken from the ``hook_cls`` property, so subclasses can
    swap in a different hook by overriding that property alone.

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["G", "C"]```|
    """

    def init_hook(self, hook_kwargs):
        """Build ``self.hook`` from ``hook_cls``.

        Args:
            hook_kwargs: extra keyword arguments, forwarded unchanged to the
                hook constructor.

        The hook's ``opts`` argument is the list of keys of
        ``self.optimizers``, passed through ``with_opt``.
        """
        optimizer_names = [*self.optimizers.keys()]
        wrapped_opts = with_opt(optimizer_names)
        hook_type = self.hook_cls
        self.hook = hook_type(opts=wrapped_opts, **hook_kwargs)

    @property
    def hook_cls(self):
        """The hook class that ``init_hook`` instantiates (``ClassifierHook``)."""
        return ClassifierHook

Finetuner

Bases: Classifier

Wraps FinetunerHook.

| Container | Required keys |
|---|---|
| models | `["G", "C"]` |
| optimizers | `["C"]` |
Source code in `pytorch_adapt/adapters/classifier.py` (lines 31–52):
class Finetuner(Classifier):
    """
    Wraps [FinetunerHook][pytorch_adapt.hooks.FinetunerHook].

    Unlike [Classifier][pytorch_adapt.adapters.Classifier], only a ``"C"``
    optimizer is created by default and required by the key enforcer.

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["C"]```|
    """

    @property
    def hook_cls(self):
        """The hook class instantiated by ``init_hook`` (``FinetunerHook``)."""
        return FinetunerHook

    def get_default_containers(self) -> MultipleContainers:
        """Return default containers holding an optimizer for ``"C"`` only.

        The optimizer is configured via ``default_optimizer_tuple()``.
        """
        c_only_optimizers = Optimizers(default_optimizer_tuple(), keys=["C"])
        return MultipleContainers(optimizers=c_only_optimizers)

    def get_key_enforcer(self) -> KeyEnforcer:
        """Return the parent's key enforcer with ``"G"`` dropped from the
        required optimizer keys.

        NOTE(review): the removal raises if the parent's optimizer
        requirements do not contain ``"G"`` (list ``remove`` semantics,
        assuming ``requirements["optimizers"]`` is a list — confirm).
        """
        enforcer = super().get_key_enforcer()
        required_optimizer_keys = enforcer.requirements["optimizers"]
        required_optimizer_keys.remove("G")
        return enforcer