base_adapter

BaseAdapter

Bases: ABC

Parent class of all adapters.

Source code in pytorch_adapt/adapters/base_adapter.py
class BaseAdapter(ABC):
    """
    Parent class of all adapters.
    """

    def __init__(
        self,
        models: Models = None,
        optimizers: Optimizers = None,
        lr_schedulers: LRSchedulers = None,
        misc: Misc = None,
        default_containers: MultipleContainers = None,
        key_enforcer: KeyEnforcer = None,
        inference_fn=None,
        before_training_starts=None,
        hook_kwargs: Dict[str, Any] = None,
    ):
        """
        Arguments:
            models: A [```Models```][pytorch_adapt.containers.Models] container.
                The models will be passed to the wrapped hook at each
                training iteration.
            optimizers: An [```Optimizers```][pytorch_adapt.containers.Optimizers] container.
                The optimizers will be passed to the wrapped hook at each
                training iteration.
            lr_schedulers: An [```LRSchedulers```][pytorch_adapt.containers.LRSchedulers] container.
                The lr schedulers are called automatically by the
            [```framework```](../frameworks/index.md) that wraps this adapter.
            misc: A [```Misc```][pytorch_adapt.containers.Misc] container for models
                that don't require optimizers, and other miscellaneous objects.
                These are passed into the wrapped hook at each training iteration.
            default_containers: The default set of containers to use, wrapped in a
                [```MultipleContainers```][pytorch_adapt.containers.MultipleContainers] object.
                If ```None``` then the default containers are defined in
                [```self.get_default_containers```][pytorch_adapt.adapters.BaseAdapter.get_default_containers]
            key_enforcer: A [```KeyEnforcer```][pytorch_adapt.containers.KeyEnforcer] object.
                If ```None```, then [```self.get_key_enforcer```][pytorch_adapt.adapters.BaseAdapter.get_key_enforcer]
                is used.
            inference_fn: A function that takes in:

                - ```x```: the input to the model
                - ```domain```: an integer representing the domain of the data
                - ```models```: a dictionary of models, i.e. ```self.models```
                - ```misc```: a dictionary of misc objects, i.e. ```self.misc```

            before_training_starts: A function that takes in this adapter and returns another
                function that is optionally called by a framework wrapper before training starts.
            hook_kwargs: A dictionary of keyword arguments that will be
                passed into the wrapped hook during initialization.
        """
        containers = c_f.default(default_containers, self.get_default_containers, {})
        self.key_enforcer = c_f.default(key_enforcer, self.get_key_enforcer, {})
        self.before_training_starts = c_f.class_default(
            self, before_training_starts, self.before_training_starts_default
        )

        containers.merge(
            models=models,
            optimizers=optimizers,
            lr_schedulers=lr_schedulers,
            misc=misc,
        )

        hook_kwargs = c_f.default(hook_kwargs, {})
        self.init_containers_and_check_keys(containers)
        self.init_hook(hook_kwargs)
        self.inference_fn = c_f.default(inference_fn, default_fn)

    def training_step(
        self, batch: Dict[str, Any], **kwargs
    ) -> Dict[str, Dict[str, float]]:
        """
        Calls the wrapped hook at each iteration during training.
        Arguments:
            batch: A dictionary containing training data.
            **kwargs: Any other data that will be passed into the hook.
        Returns:
            A two-level dictionary

                - the outer level is associated with a particular optimization step
                    (relevant for GAN architectures)

                - the inner level contains the loss components.
        """
        combined = c_f.assert_dicts_are_disjoint(
            self.models, self.misc, with_opt(self.optimizers), batch, kwargs
        )
        _, losses = self.hook(combined)
        return losses

    def inference(
        self, x: torch.Tensor, domain: int = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Arguments:
            x: The input to the model
            domain: An optional integer indicating the domain.

        Returns:
            Features and logits
        """
        return self.inference_fn(
            x=x,
            domain=domain,
            models=self.models,
            misc=self.misc,
        )

    def get_default_containers(self) -> MultipleContainers:
        """
        Returns:
            The default set of containers. Consists of an
                [Optimizers][pytorch_adapt.containers.Optimizers]
                container using the [default][pytorch_adapt.adapters.utils.default_optimizer_tuple]
                of Adam with lr 0.0001.
        """
        optimizers = Optimizers(default_optimizer_tuple())
        return MultipleContainers(optimizers=optimizers)

    @abstractmethod
    def get_key_enforcer(self) -> KeyEnforcer:
        """
        Returns:
            The default KeyEnforcer.
        """
        pass

    @abstractmethod
    def init_hook(self):
        """
        ```self.hook``` is initialized here.
        """
        pass

    @property
    @abstractmethod
    def hook_cls(self):
        pass

    def init_containers_and_check_keys(self, containers):
        """
        Called in ```__init__``` before
        [```init_hook```][pytorch_adapt.adapters.BaseAdapter.init_hook].
        """
        containers.create()
        self.key_enforcer.check(containers)
        for k, v in containers.items():
            setattr(self, k, v)

    def before_training_starts_default(self, framework):
        c_f.LOGGER.debug(f"models\n{self.models}")
        c_f.LOGGER.debug(f"optimizers\n{self.optimizers}")
        c_f.LOGGER.debug(f"lr_schedulers\n{self.lr_schedulers}")
        c_f.LOGGER.debug(f"misc\n{self.misc}")
        c_f.LOGGER.debug(f"hook\n{self.hook}")

__init__(models=None, optimizers=None, lr_schedulers=None, misc=None, default_containers=None, key_enforcer=None, inference_fn=None, before_training_starts=None, hook_kwargs=None)

Parameters:

models (Models, default None):
    A Models container. The models will be passed to the wrapped hook at each training iteration.

optimizers (Optimizers, default None):
    An Optimizers container. The optimizers will be passed to the wrapped hook at each training iteration.

lr_schedulers (LRSchedulers, default None):
    An LRSchedulers container. The lr schedulers are called automatically by the framework that wraps this adapter.

misc (Misc, default None):
    A Misc container for models that don't require optimizers, and other miscellaneous objects. These are passed into the wrapped hook at each training iteration.

default_containers (MultipleContainers, default None):
    The default set of containers to use, wrapped in a MultipleContainers object. If None, the default containers are defined in self.get_default_containers.

key_enforcer (KeyEnforcer, default None):
    A KeyEnforcer object. If None, then self.get_key_enforcer is used.

inference_fn (default None):
    A function that takes in:

      • x: the input to the model
      • domain: an integer representing the domain of the data
      • models: a dictionary of models, i.e. self.models
      • misc: a dictionary of misc objects, i.e. self.misc

before_training_starts (default None):
    A function that takes in this adapter and returns another function that is optionally called by a framework wrapper before training starts.

hook_kwargs (Dict[str, Any], default None):
    A dictionary of keyword arguments that will be passed into the wrapped hook during initialization.
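For orientation, here is a minimal construction sketch. It assumes a concrete subclass (DANN is used here purely for illustration), toy torch.nn modules standing in for real networks, and the (optimizer class, kwargs) tuple format for the Optimizers container; all names and shapes are illustrative, not prescriptive.

import torch
from pytorch_adapt.adapters import DANN
from pytorch_adapt.containers import Models, Optimizers

# Toy stand-ins for a generator (trunk), classifier head, and discriminator.
G = torch.nn.Linear(1000, 100)
C = torch.nn.Linear(100, 10)
D = torch.nn.Linear(100, 1)

models = Models({"G": G, "C": C, "D": D})
# Passing an Optimizers container overrides the default
# Adam(lr=0.0001) provided by get_default_containers.
optimizers = Optimizers((torch.optim.SGD, {"lr": 0.01}))

adapter = DANN(models=models, optimizers=optimizers)

The containers passed here are merged into the defaults via containers.merge, after which the KeyEnforcer checks that the expected model and optimizer keys are present.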

get_default_containers()

Returns:

MultipleContainers: The default set of containers. Consists of an Optimizers container using the default of Adam with lr 0.0001.
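As a rough sketch, the default is equivalent to building the containers by hand like this (assuming the (optimizer class, kwargs) tuple format; the exact value returned by default_optimizer_tuple is an assumption):

import torch
from pytorch_adapt.containers import MultipleContainers, Optimizers

# Approximately what get_default_containers returns:
# every model gets an Adam optimizer with lr 0.0001.
optimizers = Optimizers((torch.optim.Adam, {"lr": 0.0001}))
containers = MultipleContainers(optimizers=optimizers)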

get_key_enforcer() abstractmethod

Returns:

KeyEnforcer: The default KeyEnforcer.

inference(x, domain=None)

Parameters:

x (torch.Tensor, required):
    The input to the model.

domain (int, default None):
    An optional integer indicating the domain.

Returns:

Tuple[torch.Tensor, torch.Tensor]: Features and logits.
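A minimal usage sketch, assuming adapter is a fully constructed adapter whose models accept image batches of this shape (the shape is illustrative only):

import torch

x = torch.randn(32, 3, 224, 224)  # hypothetical input batch

with torch.no_grad():
    features, logits = adapter.inference(x)
    # Optionally specify the domain, e.g. 1 for the target domain.
    target_features, target_logits = adapter.inference(x, domain=1)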

init_containers_and_check_keys(containers)

Called in __init__ before init_hook.

init_hook() abstractmethod

self.hook is initialized here.
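As a sketch, a subclass defines hook_cls and builds self.hook here. Note that although the abstract signature is init_hook(self), __init__ calls it as self.init_hook(hook_kwargs), so overrides accept the hook_kwargs dict. Everything named My* below is hypothetical:

from pytorch_adapt.adapters import BaseAdapter
from pytorch_adapt.containers import KeyEnforcer

class MyHook:
    # Stand-in for a real hook from pytorch_adapt.hooks:
    # takes a dict of inputs, returns (outputs, losses).
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, inputs):
        return {}, {}

class MyAdapter(BaseAdapter):
    @property
    def hook_cls(self):
        return MyHook

    def get_key_enforcer(self) -> KeyEnforcer:
        return KeyEnforcer(models=["G", "C"], optimizers=["G", "C"])

    def init_hook(self, hook_kwargs):
        # hook_kwargs is the dict passed to __init__ as hook_kwargs
        self.hook = self.hook_cls(**hook_kwargs)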

training_step(batch, **kwargs)

Calls the wrapped hook at each iteration during training.

Parameters:

batch (Dict[str, Any], required):
    A dictionary containing training data.

**kwargs (default {}):
    Any other data that will be passed into the hook.

Returns:

Dict[str, Dict[str, float]]: A two-level dictionary:

  • the outer level is associated with a particular optimization step (relevant for GAN architectures)

  • the inner level contains the loss components.
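A minimal training-loop sketch. The batch keys shown are illustrative; the actual keys depend on what the wrapped hook expects:

for batch in dataloader:
    # e.g. batch = {"src_imgs": ..., "src_labels": ..., "target_imgs": ...}
    losses = adapter.training_step(batch)
    for step_name, loss_components in losses.items():
        # loss_components maps loss names to float values
        print(step_name, loss_components)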

BaseGCAdapter

Bases: BaseAdapter

Base class for adapters that use a Generator and Classifier.

Source code in pytorch_adapt/adapters/base_adapter.py
class BaseGCAdapter(BaseAdapter):
    """
    Base class for adapters that use a Generator and Classifier.
    """

    def get_key_enforcer(self) -> KeyEnforcer:
        return KeyEnforcer(
            models=["G", "C"],
            optimizers=["G", "C"],
        )
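So any concrete subclass must be constructed with models named "G" and "C", each of which also receives an optimizer (explicitly, or via the default optimizer container). A construction sketch, assuming Classifier is one such subclass and using toy modules:

import torch
from pytorch_adapt.adapters import Classifier
from pytorch_adapt.containers import Models

G = torch.nn.Linear(1000, 100)  # generator / trunk
C = torch.nn.Linear(100, 10)    # classifier head

adapter = Classifier(models=Models({"G": G, "C": C}))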

BaseGCDAdapter

Bases: BaseAdapter

Base class for adapters that use a Generator, Classifier, and Discriminator.

Source code in pytorch_adapt/adapters/base_adapter.py
class BaseGCDAdapter(BaseAdapter):
    """
    Base class for adapters that use a Generator, Classifier, and Discriminator.
    """

    def get_key_enforcer(self) -> KeyEnforcer:
        return KeyEnforcer(
            models=["G", "C", "D"],
            optimizers=["G", "C", "D"],
        )