Skip to content

lr_schedulers

LRSchedulers

Bases: BaseContainer

A container for optimizer learning rate schedulers.

Source code in pytorch_adapt/containers/lr_schedulers.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class LRSchedulers(BaseContainer):
    """
    A container for optimizer learning rate schedulers.
    """

    def __init__(self, store, scheduler_types=None, **kwargs):
        """
        Arguments:
            store: See [```BaseContainer```][pytorch_adapt.containers.BaseContainer]
            scheduler_types: A dictionary mapping from
                scheduler type (```"per_step"``` or ```"per_epoch"```)
                to a list of object names. If ```None```, then all
                schedulers are assumed to be ```"per_step"```
            **kwargs: [```BaseContainer```][pytorch_adapt.containers.BaseContainer]
                keyword arguments.
        """
        # Set before super().__init__ in case base-class setup reads it
        # (assumption from original statement order — preserve it).
        self.scheduler_types = scheduler_types
        super().__init__(store, **kwargs)

    def _create_with(self, other):
        # Instantiate each stored (scheduler_class, kwargs) pair with the
        # optimizer of the same name from `other`.
        to_be_deleted = []
        for k, v in self.items():
            try:
                class_ref, kwargs = v
            except TypeError:
                # Not a (class, kwargs) pair — already an instantiated
                # scheduler, leave it as-is.
                continue
            optimizer = other[k]
            if not c_f.is_optimizer(optimizer):
                # No real optimizer to attach to; drop this entry after the
                # loop (can't delete while iterating).
                to_be_deleted.append(k)
            else:
                self[k] = class_ref(optimizer, **kwargs)

        for k in to_be_deleted:
            del self[k]

    def step(self, scheduler_type: str):
        """
        Step the lr schedulers of the specified type.
        Arguments:
            scheduler_type: ```"per_step"``` or ```"per_epoch"```
        """
        for v in self.filter_by_scheduler_type(scheduler_type):
            v.step()

    def filter_by_scheduler_type(self, x):
        # Returns the schedulers belonging to type x.
        # Validate up front so an invalid type raises ValueError regardless of
        # whether scheduler_types was provided (previously a KeyError leaked
        # out when scheduler_types was not None).
        if x not in ("per_step", "per_epoch"):
            raise ValueError(
                f"scheduler types are 'per_step' or 'per_epoch', but input is '{x}'"
            )
        if self.scheduler_types is not None:
            # .get with a default: a valid type that simply has no entries
            # yields no schedulers instead of raising KeyError.
            names = self.scheduler_types.get(x, [])
            return [v for k, v in self.items() if k in names]
        # scheduler_types is None: everything is treated as "per_step".
        return self.values() if x == "per_step" else []

    def merge(self, other):
        super().merge(other)
        if other.scheduler_types is not None:
            if self.scheduler_types is not None:
                for k, v in other.scheduler_types.items():
                    # setdefault so a type present only in `other` merges in
                    # instead of raising KeyError. Union via set for dedup,
                    # matching the original list(set(...)) behavior.
                    merged = set(self.scheduler_types.setdefault(k, []))
                    merged.update(v)
                    self.scheduler_types[k] = list(merged)
            else:
                self.scheduler_types = other.scheduler_types

__init__(store, scheduler_types=None, **kwargs)

Parameters:

Name Type Description Default
store

See BaseContainer

required
scheduler_types

A dictionary mapping from scheduler type ("per_step" or "per_epoch") to a list of object names. If None, then all schedulers are assumed to be "per_step"

None
**kwargs

BaseContainer keyword arguments.

{}
Source code in pytorch_adapt/containers/lr_schedulers.py
10
11
12
13
14
15
16
17
18
19
20
21
22
def __init__(self, store, scheduler_types=None, **kwargs):
    """
    Arguments:
        store: See [```BaseContainer```][pytorch_adapt.containers.BaseContainer]
        scheduler_types: A dictionary mapping from
            scheduler type (```"per_step"``` or ```"per_epoch"```)
            to a list of object names. If ```None```, then all
            schedulers are assumed to be ```"per_step"```
        **kwargs: [```BaseContainer```][pytorch_adapt.containers.BaseContainer]
            keyword arguments.
    """
    # Assigned before super().__init__ — presumably so BaseContainer setup
    # can observe it; confirm against BaseContainer before reordering.
    self.scheduler_types = scheduler_types
    super().__init__(store, **kwargs)

step(scheduler_type)

Step the lr schedulers of the specified type.

Parameters:

Name Type Description Default
scheduler_type str

"per_step" or "per_epoch"

required
Source code in pytorch_adapt/containers/lr_schedulers.py
40
41
42
43
44
45
46
47
def step(self, scheduler_type: str):
    """
    Step the lr schedulers of the specified type.
    Arguments:
        scheduler_type: ```"per_step"``` or ```"per_epoch"```
    """
    # Delegates type selection (and validation of scheduler_type) to
    # filter_by_scheduler_type, then steps each selected scheduler.
    for v in self.filter_by_scheduler_type(scheduler_type):
        v.step()