Bases: BaseGCAdapter
Wraps SymNetsHook.
| Container | Required keys |
|---|---|
| models | `["G", "C"]` |
| optimizers | `["G", "C"]` |

The C model must output a list of logits: `[logits1, logits2]`.
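In the SymNets formulation, the two classifier heads act as a source-task classifier and a target-task classifier. Their outputs are concatenated into a single 2K-way softmax, which is why C returns two logit vectors rather than one. The dependency-free sketch below shows that contract and the combination step. `toy_c`, `symmetric_probs`, and `head_mass` are hypothetical helpers, not pytorch_adapt functions. Which head plays the source role is an assumption here, and the actual losses are computed inside SymNetsHook.

```python
import math
from typing import List, Sequence, Tuple


def toy_c(features: Sequence[float], w1: List[List[float]], w2: List[List[float]]) -> List[List[float]]:
    """Hypothetical stand-in for a two-head C model.

    Maps one feature vector to [logits1, logits2], one list of K class scores
    per head. A real C would be a torch.nn.Module; plain lists keep this
    sketch dependency-free.
    """
    logits1 = [sum(f * w for f, w in zip(features, row)) for row in w1]
    logits2 = [sum(f * w for f, w in zip(features, row)) for row in w2]
    return [logits1, logits2]


def symmetric_probs(logits1: Sequence[float], logits2: Sequence[float]) -> List[float]:
    """Softmax over the 2K-way concatenation of both heads' logits."""
    joint = list(logits1) + list(logits2)
    m = max(joint)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in joint]
    total = sum(exps)
    return [e / total for e in exps]


def head_mass(probs: Sequence[float], k: int) -> Tuple[float, float]:
    """Probability mass on the first head (first K entries) and the second head (last K)."""
    return sum(probs[:k]), sum(probs[k:])


# Usage: two identical heads with K = 3 classes split the mass evenly.
features = [1.0, 2.0]
w = [[0.5, 0.0], [0.0, 0.5], [0.1, 0.1]]
logits1, logits2 = toy_c(features, w, w)
probs = symmetric_probs(logits1, logits2)
first, second = head_mass(probs, k=3)
```

With identical heads, each half of the 2K-way distribution receives half the probability mass; raising one head's logits shifts mass toward that half, which is the signal SymNets-style domain losses operate on.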
Source code in pytorch_adapt\adapters\symnets.py
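```python
# SymNetsHook, symnets_fn, c_f, BaseGCAdapter and with_opt are imported
# at the top of the module (lines not shown in this listing).


class SymNets(BaseGCAdapter):
    """
    Wraps [SymNetsHook][pytorch_adapt.hooks.SymNetsHook].

    |Container|Required keys|
    |---|---|
    |models|```["G", "C"]```|
    |optimizers|```["G", "C"]```|

    The C model must output a list of logits: ```[logits1, logits2]```.
    """

    def __init__(self, *args, inference_fn=None, **kwargs):
        # Use symnets_fn as the inference function unless the caller supplies one.
        inference_fn = c_f.default(inference_fn, symnets_fn)
        super().__init__(*args, inference_fn=inference_fn, **kwargs)

    def init_hook(self, hook_kwargs):
        # Build the hook, telling it which optimizers belong to G and which to C.
        self.hook = self.hook_cls(
            g_opts=with_opt(["G"]), c_opts=with_opt(["C"]), **hook_kwargs
        )

    @property
    def hook_cls(self):
        # The hook class this adapter wraps.
        return SymNetsHook
```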
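The constructor and `init_hook` depend on two helpers whose source is not shown here: `c_f.default` and `with_opt`. The sketch below is a hypothetical re-implementation for illustration only. It assumes `default` returns the fallback when the first argument is `None`, and that `with_opt` maps each model name to an optimizer key using a `"<name>_opt"` convention; check the library source before relying on either behavior. `placeholder_symnets_fn` is likewise a made-up stand-in for the real `symnets_fn`.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def default(x: Optional[T], fallback: T) -> T:
    """Hypothetical equivalent of c_f.default: keep x unless it is None."""
    # Checks for None rather than falsiness, so values like 0 or [] are kept.
    return fallback if x is None else x


def with_opt(names: List[str]) -> List[str]:
    """Hypothetical equivalent of with_opt, assuming a "<name>_opt" key convention."""
    return [f"{name}_opt" for name in names]


def placeholder_symnets_fn(x):
    """Stand-in for the library's symnets_fn; it only tags its input."""
    return ("symnets_fn", x)


# Usage: no override falls back to the placeholder; an override is kept as-is.
fn = default(None, placeholder_symnets_fn)
custom = lambda x: ("custom", x)
chosen = default(custom, placeholder_symnets_fn)

# The keyword arguments init_hook would pass to the hook under these assumptions.
hook_kwargs = {"g_opts": with_opt(["G"]), "c_opts": with_opt(["C"])}
```

Under these assumptions, passing a custom `inference_fn` to `SymNets` replaces the default entirely, while the hook always receives separate optimizer lists for G and C.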