checkpoint_utils

CheckpointFnCreator

This class creates a checkpointing function for use with the Ignite wrapper.

Source code in pytorch_adapt/frameworks/ignite/checkpoint_utils.py
```python
class CheckpointFnCreator:
    """
    This class creates a checkpointing function
    for use with the [```Ignite```][pytorch_adapt.frameworks.ignite.Ignite] wrapper.
    """

    def __init__(self, **kwargs):
        """
        Arguments:
            **kwargs: Optional arguments that will be passed to PyTorch Ignite's
                [```ModelCheckpoint```](https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.checkpoint.ModelCheckpoint.html)
                class.
        """
        self.kwargs = {
            "filename_prefix": "",
            "global_step_transform": global_step_transform,
            "filename_pattern": "{filename_prefix}{name}_{global_step}.{ext}",
            **kwargs,
        }
        # Create handler here in case needed by load_objects or last_checkpoint
        # before __call__ is used
        self.objs = ModelCheckpoint(**self.kwargs)

        # For saving self.objs. Only save the very latest (n_saved = 1)
        self.ckpter = ModelCheckpoint(**{**self.kwargs, "n_saved": 1})

    def __call__(
        self,
        adapter=None,
        validator=None,
        val_hooks=None,
        **kwargs,
    ):
        """
        Creates the checkpointing function.
        Arguments:
            adapter: An [```Adapter```][pytorch_adapt.adapters.BaseAdapter] object.
            validator: A [```ScoreHistory```][pytorch_adapt.validators.ScoreHistory] object.
            val_hooks: A list of functions called during validation.
                See [```Ignite```][pytorch_adapt.frameworks.ignite.Ignite] for details.
        """
        self.objs = ModelCheckpoint(**{**self.kwargs, **kwargs})
        dict_to_save = {}
        if adapter:
            dict_to_save.update(adapter_to_dict(adapter))
        if validator:
            dict_to_save["validator"] = validator
        if val_hooks:
            dict_to_save.update(val_hooks_to_dict(val_hooks))

        def fn(engine):
            self.objs(engine, {"engine": engine, **dict_to_save})
            self.ckpter(engine, {"checkpointer": self.objs})

        return fn

    def load_objects(self, to_load, checkpoint=None, global_step=None):
        to_load = {k: v for k, v in to_load.items() if v}
        if global_step is not None:
            self.objs.reload_objects(
                to_load, name="checkpoint", global_step=global_step
            )
        else:
            self.objs.load_objects(to_load, str(checkpoint))

    def load_best_checkpoint(self, to_load):
        last_checkpoint = self.get_best_checkpoint()
        self.load_objects(to_load, last_checkpoint)

    def get_best_checkpoint(self):
        if self.objs.last_checkpoint:
            return self.objs.last_checkpoint

        ckpter_last_checkpoint = self.ckpter.last_checkpoint
        if not ckpter_last_checkpoint:
            files = glob.glob(
                os.path.join(self.ckpter.save_handler.dirname, "*checkpointer*.pt")
            )
            if len(files) > 1:
                raise ValueError("there should only be 1 matching checkpointer file")
            ckpter_last_checkpoint = files[0]

        self.ckpter.load_objects(
            {"checkpointer": self.objs}, str(ckpter_last_checkpoint)
        )
        return self.objs.last_checkpoint
```
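
A minimal usage sketch follows. It assumes the `Ignite` wrapper accepts a `checkpoint_fn` argument as in the library's examples, and that `adapter`, `validator`, `datasets`, and `dataloader_creator` have been created elsewhere; the keyword arguments shown are forwarded to Ignite's `ModelCheckpoint`.

```python
from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite

# Forwarded to ignite.handlers.ModelCheckpoint: write checkpoint files to
# "saved_models" and allow the directory to already contain files.
checkpoint_fn = CheckpointFnCreator(dirname="saved_models", require_empty=False)

# The wrapper calls checkpoint_fn(...) internally to build the handler that
# saves the adapter, validator, val_hooks, and engine at each checkpoint.
trainer = Ignite(adapter, validator=validator, checkpoint_fn=checkpoint_fn)
trainer.run(datasets, dataloader_creator=dataloader_creator)

# After training, restore the best-scoring checkpoint into the adapter's
# models ("models" is assumed here to be one of the saved keys).
checkpoint_fn.load_best_checkpoint({"models": adapter.models})
```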

__call__(adapter=None, validator=None, val_hooks=None, **kwargs)

Creates the checkpointing function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `adapter` | | An `Adapter` object. | `None` |
| `validator` | | A `ScoreHistory` object. | `None` |
| `val_hooks` | | A list of functions called during validation. See `Ignite` for details. | `None` |
Source code in pytorch_adapt/frameworks/ignite/checkpoint_utils.py
```python
def __call__(
    self,
    adapter=None,
    validator=None,
    val_hooks=None,
    **kwargs,
):
    """
    Creates the checkpointing function.
    Arguments:
        adapter: An [```Adapter```][pytorch_adapt.adapters.BaseAdapter] object.
        validator: A [```ScoreHistory```][pytorch_adapt.validators.ScoreHistory] object.
        val_hooks: A list of functions called during validation.
            See [```Ignite```][pytorch_adapt.frameworks.ignite.Ignite] for details.
    """
    self.objs = ModelCheckpoint(**{**self.kwargs, **kwargs})
    dict_to_save = {}
    if adapter:
        dict_to_save.update(adapter_to_dict(adapter))
    if validator:
        dict_to_save["validator"] = validator
    if val_hooks:
        dict_to_save.update(val_hooks_to_dict(val_hooks))

    def fn(engine):
        self.objs(engine, {"engine": engine, **dict_to_save})
        self.ckpter(engine, {"checkpointer": self.objs})

    return fn
```
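
For illustration, the function returned by `__call__` takes an Ignite `Engine` and can be attached as an event handler. The `Ignite` wrapper normally does this itself; the sketch below only shows the shape of the returned function, with `adapter` and `validator` assumed to exist already.

```python
from ignite.engine import Engine, Events

def train_step(engine, batch):
    ...  # placeholder training step

trainer_engine = Engine(train_step)

creator = CheckpointFnCreator(dirname="saved_models", require_empty=False)
fn = creator(adapter=adapter, validator=validator)

# fn(engine) saves the adapter's objects, the validator, and the engine
# itself; here it runs at the end of every epoch.
trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, fn)
```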

__init__(**kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**kwargs` | | Optional arguments that will be passed to PyTorch Ignite's [ModelCheckpoint](https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.checkpoint.ModelCheckpoint.html) class. | `{}` |
Source code in pytorch_adapt/frameworks/ignite/checkpoint_utils.py
```python
def __init__(self, **kwargs):
    """
    Arguments:
        **kwargs: Optional arguments that will be passed to PyTorch Ignite's
            [```ModelCheckpoint```](https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.checkpoint.ModelCheckpoint.html)
            class.
    """
    self.kwargs = {
        "filename_prefix": "",
        "global_step_transform": global_step_transform,
        "filename_pattern": "{filename_prefix}{name}_{global_step}.{ext}",
        **kwargs,
    }
    # Create handler here in case needed by load_objects or last_checkpoint
    # before __call__ is used
    self.objs = ModelCheckpoint(**self.kwargs)

    # For saving self.objs. Only save the very latest (n_saved = 1)
    self.ckpter = ModelCheckpoint(**{**self.kwargs, "n_saved": 1})
```
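
As a quick illustration, any `ModelCheckpoint` keyword argument can be overridden through `**kwargs`; the directory name and counts below are arbitrary example values.

```python
creator = CheckpointFnCreator(
    dirname="saved_models",   # directory where checkpoint files are written
    n_saved=3,                # keep only the 3 most recent checkpoints
    require_empty=False,      # don't error if the directory already has files
)
```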