Skip to content

base_container

BaseContainer

Bases: MutableMapping

The parent class of all containers.

Containers are dictionaries with extra functionality that simplifies object creation.

Source code in pytorch_adapt\containers\base_container.py (lines 7–158):
class BaseContainer(MutableMapping):
    """
    The parent class of all containers.

    Containers are dictionaries with extra functionality
    that simplifies object creation.
    """

    def __init__(self, store, other=None, keys=None):
        """
        Arguments:
            store: A tuple or dictionary

                - A tuple consists of ```(<class_ref>, <init kwargs>)```.
                For example, ```(torch.optim.Adam, {"lr": 0.1})```
                - A dictionary maps from object name to either a tuple
                or a fully constructed object.

            other: Another container which is used in the process
                of creating objects in this container (e.g.,
                optimizers require model parameters).

            keys: Converts ```store``` from tuple to dict format,
                where each dict value is the tuple. This only works
                if ```store``` is passed in as a tuple.

        Raises:
            TypeError: if ``store`` is neither a tuple nor a dict, or if
                ``keys`` is given while ``store`` is not a tuple.
        """
        if not isinstance(store, (tuple, dict)):
            raise TypeError("BaseContainer input must be a tuple or dict")
        if isinstance(store, tuple):
            # The tuple is held aside until it is expanded into per-key
            # entries by duplicate() or create_with().
            self.store_as_tuple = store
            self.store = {}
        else:
            self.store_as_tuple = None
            # NOTE: the caller's dict is stored by reference, not copied.
            self.store = store
        if keys is not None:
            self.duplicate(keys)
        if other is not None:
            self.create_with(other)

    def __getitem__(self, key):
        return self.store[self._keytransform(key)]

    def __setitem__(self, key, value):
        self.store[self._keytransform(key)] = value

    def __delitem__(self, key):
        del self.store[self._keytransform(key)]

    def __iter__(self):
        # Yields the raw keys of the store (no _keytransform applied).
        return iter(self.store)

    def __len__(self):
        return len(self.store)

    def _keytransform(self, key):
        """
        Map a caller-supplied key to the key used in ``self.store``.

        This is the identity here. It is applied by ``__getitem__``,
        ``__setitem__`` and ``__delitem__`` (not by ``__iter__``), presumably
        so that subclasses can override it to normalize keys.
        """
        return key

    def __repr__(self):
        # One "key: value" line per entry; an empty container gives "".
        if isinstance(self.store, dict):
            return "".join(f"{k}: {v}\n" for k, v in self.items())
        return str(self.store)

    def merge(self, other: "BaseContainer"):
        """
        Merges another container into this one.

        If ``other`` still holds an unexpanded tuple (``store_as_tuple``),
        that tuple replaces every value already in this container. When this
        container is empty, the tuple is instead adopted as this container's
        ``store_as_tuple``. Otherwise, every entry of ``other`` is copied in,
        overwriting entries that share a key.

        Arguments:
            other: The container that will be merged into this container.

        Raises:
            TypeError: if ``other`` is not a ``BaseContainer``.
        """
        if not isinstance(other, BaseContainer):
            raise TypeError("merge can only be done with another container")
        if other.store_as_tuple:
            if len(self) > 0:
                # Only keys are needed. Snapshot them so that assignment can
                # never resize the store while it is being iterated.
                for k in list(self):
                    self[k] = other.store_as_tuple
            else:
                self.store_as_tuple = other.store_as_tuple
        else:
            # MutableMapping.update assigns self[k] = other[k] for every key.
            self.update(other)

    def create(self):
        """
        Initializes objects by converting all
        tuples in the store into objects.

        A 2-tuple ``(class_ref, kwargs)`` becomes ``class_ref(**kwargs)``.
        A 1-tuple ``(obj,)`` is unwrapped to ``obj``. Non-tuple values are
        left untouched. Entries marked for deletion are removed afterwards.

        Raises:
            ValueError: if a tuple value has a length other than 1 or 2.
        """
        for k, v in self.items():
            if isinstance(v, tuple):
                if len(v) == 2:
                    class_ref, kwargs = v
                    self[k] = class_ref(**kwargs)
                elif len(v) == 1:
                    self[k] = v[0]
                else:
                    raise ValueError(
                        f"The tuple {v} has length={len(v)}, but it must be of length 1 or 2"
                    )
        self.delete_unwanted_keys()

    def create_with(self, other):
        """
        Initializes objects conditioned on the input container.

        Any stored tuple is expanded into one entry per key of ``other``
        (via ``type_check``). The explicit dict entries are then laid on top,
        so they win on key clashes. The merged dict becomes the store, entries
        marked for deletion are dropped, and finally the ``_create_with`` hook
        runs with ``other``.
        """
        self.store_as_tuple = self.type_check(self.store_as_tuple, other)
        self.store = self.type_check(self.store, other)
        # Explicit entries override those expanded from the tuple.
        self.store_as_tuple.update(self.store)
        self.store = self.store_as_tuple
        self.store_as_tuple = None
        self.delete_unwanted_keys()
        self._create_with(other)

    def _create_with(self, other):
        """
        Hook called at the end of ``create_with`` with the same ``other``.
        It does nothing here.
        """
        pass

    def type_check(self, store, other):
        """
        Normalize ``store`` into a dict.

        Arguments:
            store: A tuple, which is expanded to ``{k: store for k in other.keys()}``;
                a dict, which is returned unchanged (not copied); or ``None``,
                which becomes a new empty dict.
            other: A mapping whose keys are used to expand a tuple.

        Raises:
            TypeError: if ``store`` is of any other type.
        """
        if isinstance(store, tuple):
            return {k: store for k in other.keys()}
        elif isinstance(store, dict):
            return store
        elif store is None:
            return {}
        raise TypeError(
            f"store must be a tuple, dict, or None, but got {type(store).__name__}"
        )

    def duplicate(self, keys):
        """
        Replace the store with a dict that maps each of ``keys`` to the
        stored tuple, then clear ``store_as_tuple``.

        Raises:
            TypeError: if no tuple is currently stored in ``store_as_tuple``.
        """
        if isinstance(self.store_as_tuple, tuple):
            self.store = {k: self.store_as_tuple for k in keys}
            self.store_as_tuple = None
        else:
            raise TypeError("If keys are specified, store must be a tuple.")

    def apply(self, function, keys=None):
        """
        Replace ``self[k]`` with ``function(self[k])`` in place.

        Arguments:
            function: A one-argument callable.
            keys: The keys to transform. All keys are used when ``None``.
        """
        if keys is None:
            # Snapshot so that replacing values cannot disturb iteration.
            keys = list(self.keys())
        for k in keys:
            self[k] = function(self[k])

    def delete_unwanted_keys(self):
        """
        Remove every entry marked for deletion.

        An entry is marked if its value is a ``DeleteKey`` instance, or a
        non-empty tuple whose first element is the ``DeleteKey`` class itself.
        """
        # Use `is`, not `==`: the first element may be an array-like whose
        # `==` is elementwise and cannot be coerced to bool. The len() guard
        # avoids IndexError on an empty tuple.
        del_list = [
            k
            for k, v in self.items()
            if isinstance(v, DeleteKey)
            or (isinstance(v, tuple) and len(v) > 0 and v[0] is DeleteKey)
        ]
        # Delete after the scan so the store is not resized mid-iteration.
        for k in del_list:
            del self[k]

    def state_dict(self):
        """
        Return ``{k: v.state_dict()}`` for every value ``v`` that has a
        ``state_dict`` attribute. Other values are skipped.
        """
        return {k: v.state_dict() for k, v in self.items() if hasattr(v, "state_dict")}

    def load_state_dict(self, state_dict):
        """
        Call ``self[k].load_state_dict(v)`` for every ``(k, v)`` in ``state_dict``.

        The keys are first checked with ``c_f.assert_state_dict_keys``. That
        helper presumably verifies they match this container's keys; confirm
        against its implementation.
        """
        c_f.assert_state_dict_keys(state_dict, self.keys())
        for k, v in state_dict.items():
            self[k].load_state_dict(v)

__init__(store, other=None, keys=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `store` | | A tuple or dictionary. A tuple consists of `(<class_ref>, <init kwargs>)`, for example `(torch.optim.Adam, {"lr": 0.1})`. A dictionary maps from object name to either a tuple or a fully constructed object. | required |
| `other` | | Another container which is used in the process of creating objects in this container (e.g., optimizers require model parameters). | `None` |
| `keys` | | Converts `store` from tuple to dict format, where each dict value is the tuple. This only works if `store` is passed in as a tuple. | `None` |
Source code in pytorch_adapt\containers\base_container.py (lines 15–44):
def __init__(self, store, other=None, keys=None):
    """
    Arguments:
        store: Either a tuple or a dict.

            - Tuple form: ```(<class_ref>, <init kwargs>)```,
            e.g. ```(torch.optim.Adam, {"lr": 0.1})```. It is held in
            ``store_as_tuple`` and the dict store starts out empty.
            - Dict form: maps each object name to a tuple or to an
            already-constructed object. It is kept as-is (not copied).

        other: Optional container consulted while creating this
            container's objects (e.g., optimizers need model parameters).
            When given, ``create_with(other)`` runs at the end of
            construction.

        keys: Optional iterable of names. When given, ``duplicate(keys)``
            turns the tuple form into a dict that maps every name to that
            tuple. This is only valid when ``store`` is a tuple.

    Raises:
        TypeError: if ``store`` is neither a tuple nor a dict.
    """
    if isinstance(store, tuple):
        self.store_as_tuple, self.store = store, {}
    elif isinstance(store, dict):
        self.store_as_tuple, self.store = None, store
    else:
        raise TypeError("BaseContainer input must be a tuple or dict")

    # Expand the tuple over explicit keys first, then build against `other`.
    if keys is not None:
        self.duplicate(keys)
    if other is not None:
        self.create_with(other)

create()

Initializes objects by converting all tuples in the store into objects.

Source code in pytorch_adapt\containers\base_container.py (lines 90–106):
def create(self):
    """
    Build every object whose entry in the store is still a tuple.

    A 2-tuple ``(class_ref, kwargs)`` is replaced by ``class_ref(**kwargs)``.
    A 1-tuple ``(obj,)`` is unwrapped to ``obj``. Values that are not tuples
    are left alone. Afterwards, entries marked for deletion are removed via
    ``delete_unwanted_keys``.

    Raises:
        ValueError: if a tuple has any length other than 1 or 2.
    """
    for name, spec in self.items():
        if not isinstance(spec, tuple):
            continue
        size = len(spec)
        if size == 1:
            self[name] = spec[0]
        elif size == 2:
            class_ref, kwargs = spec
            self[name] = class_ref(**kwargs)
        else:
            raise ValueError(
                f"The tuple {spec} has length={size}, but it must be of length 1 or 2"
            )
    self.delete_unwanted_keys()

create_with(other)

Initializes objects conditioned on the input container.

Source code in pytorch_adapt\containers\base_container.py (lines 108–118):
def create_with(self, other):
    """
    Initializes objects conditioned on the input container.

    Both the pending tuple (``store_as_tuple``) and the explicit dict
    (``store``) are normalized against ``other`` with ``type_check``. The
    explicit entries are then laid over the expanded ones, so they win on key
    clashes. The merged dict becomes the new store, ``store_as_tuple`` is
    cleared, deletion markers are dropped, and the ``_create_with`` hook
    runs last.
    """
    expanded = self.type_check(self.store_as_tuple, other)
    explicit = self.type_check(self.store, other)
    # Explicit entries override the ones expanded from the tuple.
    expanded.update(explicit)
    self.store, self.store_as_tuple = expanded, None
    self.delete_unwanted_keys()
    self._create_with(other)

merge(other)

Merges another container into this one.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `other` | `'BaseContainer'` | The container that will be merged into this container. | required |
Source code in pytorch_adapt\containers\base_container.py (lines 72–88):
def merge(self, other: "BaseContainer"):
    """
    Merges another container into this one.

    If ``other`` holds an unexpanded tuple (``store_as_tuple``), that tuple
    overwrites every value already in this container. When this container is
    empty, the tuple is instead adopted as this container's
    ``store_as_tuple``. Otherwise, every entry of ``other`` is copied in,
    overwriting entries that share a key.

    Arguments:
        other: The container that will be merged into this container.

    Raises:
        TypeError: if ``other`` is not a ``BaseContainer``.
    """
    if not isinstance(other, BaseContainer):
        raise TypeError("merge can only be done with another container")
    if other.store_as_tuple:
        if len(self) > 0:
            # Only keys are needed. Snapshot them so that assigning cannot
            # resize the store while it is being iterated.
            for k in list(self):
                self[k] = other.store_as_tuple
        else:
            self.store_as_tuple = other.store_as_tuple
    else:
        # MutableMapping.update assigns self[k] = other[k] for every key.
        self.update(other)