Skip to content

lightning

Lightning

Bases: pl.LightningModule

Converts an Adapter into a PyTorch Lightning module.

Source code in pytorch_adapt/frameworks/lightning/lightning.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class Lightning(pl.LightningModule):
    """
    Converts an [Adapter](../../adapters/index.md) into a PyTorch
    Lightning module.
    """

    def __init__(self, adapter, validator=None):
        """
        Arguments:
            adapter: the adapter whose models, optimizers, and training
                step will be driven by Lightning.
            validator: an optional callable with a ``required_data``
                attribute; it receives collected validation outputs and
                returns a score.
        """
        super().__init__()
        # Wrap the adapter's models/misc in ModuleDicts so Lightning
        # tracks their parameters, devices, and checkpoint state.
        self.models = torch.nn.ModuleDict(adapter.models)
        self.misc = torch.nn.ModuleDict(adapter.misc)
        # Point the adapter back at the wrapped containers so both
        # sides share the exact same module objects.
        adapter.models = self.models
        adapter.misc = self.misc
        self.validator = validator
        self.adapter = adapter
        # The adapter performs its own optimizer steps inside
        # training_step, so Lightning's automatic optimization is off.
        self.automatic_optimization = False

    def forward(self, x, domain=None):
        """Delegates inference to the wrapped adapter."""
        return self.adapter.inference(x, domain=domain)

    def training_step(self, batch, batch_idx):
        """Runs the adapter's training step and logs its losses."""
        # Hand Lightning-managed optimizers to the adapter each step so
        # it drives them via manual optimization.
        set_adapter_optimizers_to_pl(self.adapter, self.optimizers())
        losses = self.adapter.training_step(
            batch,
            # Use Lightning's manual_backward so precision plugins etc.
            # are applied to the adapter's backward pass.
            custom_backward=self.manual_backward,
        )
        for k, v in losses.items():
            self.log(k, v)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Collects per-batch outputs for the validator."""
        return f_utils.collector_step(self, batch, f_utils.create_output_dict)

    def validation_epoch_end(self, outputs):
        """Aggregates collected outputs and logs the validation score."""
        required_data = self.validator.required_data
        if len(required_data) > 1:
            # One dataloader per required-data key.
            outputs = multi_dataloader_collect(outputs)
            data = dict(zip(required_data, outputs))
        else:
            outputs = single_dataloader_collect(outputs)
            data = {required_data[0]: outputs}
        score = self.validator(**data)
        self.log("validation_score", score)

    def configure_optimizers(self):
        """Returns the adapter's optimizers and interval-tagged schedulers."""
        optimizers = list(self.adapter.optimizers.values())
        lr_schedulers = []
        # Tag each scheduler with the interval Lightning should step it on.
        for interval in ["epoch", "step"]:
            for v in self.adapter.lr_schedulers.filter_by_scheduler_type(
                f"per_{interval}"
            ):
                lr_schedulers.append({"lr_scheduler": v, "interval": interval})
        return optimizers, lr_schedulers

__init__(adapter, validator=None)

Parameters:

Name       Type   Description                                            Default
adapter           The adapter to convert into a LightningModule.         required
validator         Optional validator used to compute a validation score. None
Source code in pytorch_adapt/frameworks/lightning/lightning.py
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def __init__(self, adapter, validator=None):
    """
    Arguments:
        adapter: the adapter whose models, optimizers, and training
            step will be driven by Lightning.
        validator: an optional callable with a ``required_data``
            attribute; it receives collected validation outputs and
            returns a score.
    """
    super().__init__()
    # Wrap the adapter's models/misc in ModuleDicts so Lightning
    # tracks their parameters, devices, and checkpoint state.
    self.models = torch.nn.ModuleDict(adapter.models)
    self.misc = torch.nn.ModuleDict(adapter.misc)
    # Point the adapter back at the wrapped containers so both
    # sides share the exact same module objects.
    adapter.models = self.models
    adapter.misc = self.misc
    self.validator = validator
    self.adapter = adapter
    # The adapter performs its own optimizer steps, so Lightning's
    # automatic optimization is disabled.
    self.automatic_optimization = False