Bases: pl.LightningModule
Converts an Adapter into a PyTorch
Lightning module.
Source code in `pytorch_adapt/frameworks/lightning/lightning.py`
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class Lightning(pl.LightningModule):
    """
    Converts an [Adapter](../../adapters/index.md) into a PyTorch
    Lightning module.
    """

    def __init__(self, adapter, validator=None):
        """
        Arguments:
            adapter: The adapter to wrap. Its ``models`` and ``misc``
                containers are wrapped in ``torch.nn.ModuleDict`` and
                registered on this module, and the adapter is re-pointed at
                those same wrapped containers.
            validator: Optional. Called in ``validation_epoch_end`` to compute
                the logged ``"validation_score"``. If ``None``, validation
                outputs are still collected but no score is computed.
        """
        super().__init__()
        self.models = torch.nn.ModuleDict(adapter.models)
        self.misc = torch.nn.ModuleDict(adapter.misc)
        # Make the adapter and this module operate on the very same
        # (registered) submodule containers.
        adapter.models = self.models
        adapter.misc = self.misc
        self.validator = validator
        self.adapter = adapter
        # Optimization is handed to the adapter: training_step gives it the
        # optimizers from self.optimizers() and self.manual_backward, so
        # Lightning's automatic optimization must be off.
        self.automatic_optimization = False

    def forward(self, x, domain=None):
        """Run inference through the wrapped adapter.

        Arguments:
            x: Input forwarded unchanged to ``adapter.inference``.
            domain: Optional domain forwarded to ``adapter.inference``.

        Returns:
            Whatever ``self.adapter.inference(x, domain=domain)`` returns.
        """
        return self.adapter.inference(x, domain=domain)

    def training_step(self, batch, batch_idx):
        """Run one adapter training step and log each returned loss entry.

        The optimizers returned by ``self.optimizers()`` are handed to
        ``set_adapter_optimizers_to_pl`` first, and ``self.manual_backward``
        is passed to the adapter as its backward function.
        """
        set_adapter_optimizers_to_pl(self.adapter, self.optimizers())
        losses = self.adapter.training_step(
            batch,
            custom_backward=self.manual_backward,
        )
        # NOTE(review): assumes every value is something self.log accepts
        # (number/scalar tensor, or a flat dict of them) — confirm against
        # the adapter's training_step return type.
        for name, value in losses.items():
            self.log(name, value)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Collect per-batch outputs for ``validation_epoch_end``.

        Delegates to ``f_utils.collector_step`` using
        ``f_utils.create_output_dict`` to build the per-batch output.
        """
        return f_utils.collector_step(self, batch, f_utils.create_output_dict)

    # NOTE(review): the validation_epoch_end hook exists in Lightning 1.x and
    # was removed in Lightning 2.0 — confirm the supported Lightning version.
    def validation_epoch_end(self, outputs):
        """Score the collected validation outputs and log ``validation_score``.

        If the validator requires more than one dataset, ``outputs`` is
        combined with ``multi_dataloader_collect`` and matched positionally
        to ``self.validator.required_data``; otherwise it is combined with
        ``single_dataloader_collect`` and bound to the single required name.

        Does nothing when no validator was supplied.
        """
        if self.validator is None:
            # validator defaults to None in __init__; without this guard the
            # attribute access below raises AttributeError at epoch end.
            return
        required_data = self.validator.required_data
        if len(required_data) > 1:
            outputs = multi_dataloader_collect(outputs)
            data = dict(zip(required_data, outputs))
        else:
            outputs = single_dataloader_collect(outputs)
            data = {required_data[0]: outputs}
        score = self.validator(**data)
        self.log("validation_score", score)

    def configure_optimizers(self):
        """Expose the adapter's optimizers and LR schedulers to Lightning.

        Returns:
            A tuple ``(optimizers, lr_scheduler_configs)``. ``optimizers`` is
            the list of the adapter's optimizers. Each scheduler config is a
            dict holding the scheduler under ``"scheduler"`` and its
            ``"interval"`` (``"epoch"`` or ``"step"``), taken from the
            adapter's ``per_epoch`` / ``per_step`` scheduler groups.
        """
        optimizers = list(self.adapter.optimizers.values())
        lr_schedulers = []
        for interval in ("epoch", "step"):
            for scheduler in self.adapter.lr_schedulers.filter_by_scheduler_type(
                f"per_{interval}"
            ):
                # Lightning's lr_scheduler_config stores the scheduler under
                # "scheduler"; "lr_scheduler" is only the top-level key of the
                # {"optimizer": ..., "lr_scheduler": ...} return form.
                lr_schedulers.append({"scheduler": scheduler, "interval": interval})
        return optimizers, lr_schedulers
|
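`__init__(adapter, validator=None)`

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `adapter` | | The adapter to wrap. Its `models` and `misc` containers are registered on the module as `torch.nn.ModuleDict`s, and the adapter is updated to use those same containers. | required |
| `validator` | | Optional. Used in `validation_epoch_end` to compute the logged `validation_score`. | `None` |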
Source code in `pytorch_adapt/frameworks/lightning/lightning.py`
34
35
36
37
38
39
40
41
42
43
44
45
46
def __init__(self, adapter, validator=None):
    """
    Arguments:
        adapter: Adapter whose ``models`` and ``misc`` mappings get wrapped
            in ``torch.nn.ModuleDict`` containers owned by this module; the
            adapter is then switched over to use those same containers.
        validator: Optional scorer stored as ``self.validator``.
    """
    super().__init__()
    wrapped_models = torch.nn.ModuleDict(adapter.models)
    wrapped_misc = torch.nn.ModuleDict(adapter.misc)
    # Register on self first, then hand the identical objects back to the
    # adapter so both refer to one set of submodules.
    self.models, self.misc = wrapped_models, wrapped_misc
    adapter.models, adapter.misc = wrapped_models, wrapped_misc
    self.validator, self.adapter = validator, adapter
    # The adapter drives backward/optimizer stepping itself.
    self.automatic_optimization = False
|