base

BaseConditionHook

Bases: BaseHook

The base class for hooks that return a boolean

Source code in pytorch_adapt/hooks/base.py
class BaseConditionHook(BaseHook):
    """The base class for hooks that return a boolean"""

    def _loss_keys(self):
        """"""
        return []

    def _out_keys(self):
        """"""
        return []
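
For illustration, a minimal condition hook might look like the following sketch (the class name and the condition are hypothetical, not part of pytorch_adapt). Since BaseConditionHook already defines _loss_keys and _out_keys, a subclass only needs to implement call and return a boolean.

from pytorch_adapt.hooks.base import BaseConditionHook

class KeyExistsHook(BaseConditionHook):
    """Hypothetical example: returns True if a given key is present in the context."""

    def __init__(self, key, **kwargs):
        super().__init__(**kwargs)
        self.key = key

    def call(self, inputs, losses):
        # __call__ recognizes the boolean return value and passes it through unchanged
        return self.key in inputs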

BaseHook

Bases: ABC

All hooks extend BaseHook

Source code in pytorch_adapt/hooks/base.py
class BaseHook(ABC):
    """All hooks extend ```BaseHook```"""

    def __init__(
        self,
        loss_prefix: str = "",
        loss_suffix: str = "",
        out_prefix: str = "",
        out_suffix: str = "",
        key_map: Dict[str, str] = None,
    ):
        """
        Arguments:
            loss_prefix: prepended to all new loss keys
            loss_suffix: appended to all new loss keys
            out_prefix: prepended to all new output keys
            out_suffix: appended to all new output keys
            key_map: a mapping from ```input_key``` to ```new_key```.
                For example, if key_map = {"A": "B"}, and the input dict to ```__call__``` is {"A": 5},
                then the input will be converted to {"B": 5} before being consumed. Before exiting ```__call__```,
                the mapping is undone so the input context is preserved.
                In other words, {"B": 5} will be converted back to {"A": 5}.
        """
        if any(
            not isinstance(x, str)
            for x in [loss_prefix, loss_suffix, out_prefix, out_suffix]
        ):
            raise TypeError("loss prefix/suffix and out prefix/suffix must be strings")
        self.loss_prefix = loss_prefix
        self.loss_suffix = loss_suffix
        self.out_prefix = out_prefix
        self.out_suffix = out_suffix
        self.key_map = c_f.default(key_map, {})
        self.in_keys = []
        self.logger = HookLogger(c_f.cls_name(self))

    def __call__(self, inputs, losses=None):
        self.logger("__call__")
        losses = c_f.default(losses, {})
        try:
            inputs = c_f.map_keys(inputs, self.key_map)
            x = self.call(inputs, losses)
            if isinstance(x, (bool, np.bool_)):
                self.logger.reset()
                return x
            elif isinstance(x, tuple):
                outputs, losses = x
                outputs = replace_mapped_keys(outputs, self.key_map)
                inputs = replace_mapped_keys(inputs, self.key_map)
                outputs = wrap_keys(outputs, self.out_prefix, self.out_suffix)
                losses = wrap_keys(losses, self.loss_prefix, self.loss_suffix)
                self.check_losses_and_outputs(outputs, losses, inputs)
                self.logger.reset()
                return outputs, losses
            else:
                raise TypeError(
                    f"Output is of type {type(x)}, but should be bool or tuple"
                )
        except Exception as e:
            c_f.add_error_message(e, f"in {self.logger.str}\n", prepend=True)
            self.logger.reset()
            raise

    @abstractmethod
    def call(
        self, inputs: Dict[str, Any], losses: Dict[str, Any]
    ) -> Union[Tuple[Dict[str, Any], Dict[str, Any]], bool]:
        """
        This gets called by ```__call__``` and must be implemented by the child class.
        Arguments:
            inputs: holds data and models
            losses: previously computed losses
        Returns:
            Either a tuple of ```(outputs, losses)``` that will be merged with the input context,
                or a boolean
        """
        pass

    @abstractmethod
    def _loss_keys(self) -> List[str]:
        """
        This must be implemented by the child class
        Returns:
            The names of the losses that will be added to the context.
        """
        pass

    @property
    def loss_keys(self):
        return list(
            set(wrap_keys(self._loss_keys(), self.loss_prefix, self.loss_suffix))
        )

    @abstractmethod
    def _out_keys(self) -> List[str]:
        """
        This must be implemented by the child class
        Returns:
            The names of the outputs that will be added to the context.
        """
        pass

    @property
    def out_keys(self):
        x = replace_mapped_keys(self._out_keys(), self.key_map)
        return list(set(wrap_keys(x, self.out_prefix, self.out_suffix)))

    def set_in_keys(self, in_keys):
        self.in_keys = in_keys

    def __repr__(self):
        return c_f.nice_repr(self, self.extra_repr(), self.children_repr())

    def extra_repr(self):
        return ""

    def children_repr(self):
        all_hooks = c_f.attrs_of_type(self, BaseHook)
        all_modules = c_f.attrs_of_type(self, torch.nn.Module)
        return c_f.assert_dicts_are_disjoint(all_hooks, all_modules)

    def check_losses_and_outputs(self, outputs, losses, inputs):
        check_keys_are_present(self, self.loss_keys, [losses], "loss_keys", "losses")
        check_keys_are_present(
            self, self.out_keys, [inputs, outputs], "out_keys", "inputs or outputs"
        )
        check_keys_are_present(self, losses, self.loss_keys, "loss_keys", "losses")
        check_keys_are_present(self, outputs, self.out_keys, "outputs", "out_keys")
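
To make the contract concrete, here is a minimal sketch of a subclass (the class name and the "G" and "src_imgs" context keys are hypothetical, chosen only for illustration). call returns an (outputs, losses) tuple, while _out_keys and _loss_keys declare exactly which keys those dicts will add to the context; __call__ then verifies that the declared and returned keys match.

from pytorch_adapt.hooks.base import BaseHook

class ExampleFeaturesHook(BaseHook):
    """Hypothetical example: runs a model on some input data."""

    def call(self, inputs, losses):
        # assumes the context contains a model under "G" and data under "src_imgs"
        features = inputs["G"](inputs["src_imgs"])
        # one new output, no new losses
        return {"example_features": features}, {}

    def _loss_keys(self):
        return []

    def _out_keys(self):
        return ["example_features"]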

__init__(loss_prefix='', loss_suffix='', out_prefix='', out_suffix='', key_map=None)

Parameters:

- loss_prefix (str, default ''): prepended to all new loss keys
- loss_suffix (str, default ''): appended to all new loss keys
- out_prefix (str, default ''): prepended to all new output keys
- out_suffix (str, default ''): appended to all new output keys
- key_map (Dict[str, str], default None): a mapping from input_key to new_key. For example, if key_map = {"A": "B"} and the input dict to __call__ is {"A": 5}, the input is converted to {"B": 5} before being consumed. Before exiting __call__, the mapping is undone so the input context is preserved; in other words, {"B": 5} is converted back to {"A": 5}.
Source code in pytorch_adapt/hooks/base.py
def __init__(
    self,
    loss_prefix: str = "",
    loss_suffix: str = "",
    out_prefix: str = "",
    out_suffix: str = "",
    key_map: Dict[str, str] = None,
):
    """
    Arguments:
        loss_prefix: prepended to all new loss keys
        loss_suffix: appended to all new loss keys
        out_prefix: prepended to all new output keys
        out_suffix: appended to all new output keys
        key_map: a mapping from ```input_key``` to ```new_key```.
            For example, if key_map = {"A": "B"}, and the input dict to ```__call__``` is {"A": 5},
            then the input will be converted to {"B": 5} before being consumed. Before exiting ```__call__```,
            the mapping is undone so the input context is preserved.
            In other words, {"B": 5} will be converted back to {"A": 5}.
    """
    if any(
        not isinstance(x, str)
        for x in [loss_prefix, loss_suffix, out_prefix, out_suffix]
    ):
        raise TypeError("loss prefix/suffix and out prefix/suffix must be strings")
    self.loss_prefix = loss_prefix
    self.loss_suffix = loss_suffix
    self.out_prefix = out_prefix
    self.out_suffix = out_suffix
    self.key_map = c_f.default(key_map, {})
    self.in_keys = []
    self.logger = HookLogger(c_f.cls_name(self))
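
Continuing the hypothetical ExampleFeaturesHook sketch from above, the prefix and suffix arguments only change how new keys are named, and key_map only changes which context keys the hook reads; neither affects the computation itself.

import torch

# placeholder model and data for the hypothetical ExampleFeaturesHook defined earlier
model = torch.nn.Linear(10, 2)
data = torch.randn(4, 10)

hook = ExampleFeaturesHook(out_prefix="detached_")
outputs, losses = hook({"G": model, "src_imgs": data})
# hook.out_keys == ["detached_example_features"]
# outputs == {"detached_example_features": <4x2 tensor>}, losses == {}

# key_map renames context keys before call() consumes them, e.g.
# ExampleFeaturesHook(key_map={"target_imgs": "src_imgs"}) would read the
# "target_imgs" entry of the context under the internal name "src_imgs",
# and the mapping is undone before __call__ returns.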

call(inputs, losses) abstractmethod

This gets called by __call__ and must be implemented by the child class.

Parameters:

- inputs (Dict[str, Any], required): holds data and models
- losses (Dict[str, Any], required): previously computed losses

Returns:

- Union[Tuple[Dict[str, Any], Dict[str, Any]], bool]: either a tuple of (outputs, losses) that will be merged with the input context, or a boolean
Source code in pytorch_adapt/hooks/base.py
@abstractmethod
def call(
    self, inputs: Dict[str, Any], losses: Dict[str, Any]
) -> Union[Tuple[Dict[str, Any], Dict[str, Any]], bool]:
    """
    This gets called by ```__call__``` and must be implemented by the child class.
    Arguments:
        inputs: holds data and models
        losses: previously computed losses
    Returns:
        Either a tuple of ```(outputs, losses)``` that will be merged with the input context,
            or a boolean
    """
    pass
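
As a complement to the ExampleFeaturesHook sketch above, a hook that contributes a loss rather than an output returns the loss in the second element of the tuple and declares it in _loss_keys (again, all names here are hypothetical).

from pytorch_adapt.hooks.base import BaseHook

class ExampleLossHook(BaseHook):
    """Hypothetical example: computes a loss from an entry already in the context."""

    def call(self, inputs, losses):
        # assumes "example_features" was added to the context by a previous hook
        loss = inputs["example_features"].pow(2).mean()
        return {}, {"example_loss": loss}

    def _loss_keys(self):
        return ["example_loss"]

    def _out_keys(self):
        return []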

BaseWrapperHook

Bases: BaseHook

A simple wrapper for calling self.hook, which should be defined in the child's __init__ function.

Source code in pytorch_adapt/hooks/base.py
class BaseWrapperHook(BaseHook):
    """A simple wrapper for calling ```self.hook```,
    which should be defined in the child's ```__init__``` function."""

    def call(self, *args, **kwargs):
        """"""
        return self.hook(*args, **kwargs)

    def _loss_keys(self):
        """"""
        return self.hook.loss_keys

    def _out_keys(self):
        """"""
        return self.hook.out_keys
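
For illustration, a wrapper only needs to assign self.hook in __init__; call, _loss_keys, and _out_keys are then delegated automatically. The sketch below is hypothetical and reuses the ExampleFeaturesHook sketch from the BaseHook section.

from pytorch_adapt.hooks.base import BaseWrapperHook

class ExampleWrapperHook(BaseWrapperHook):
    """Hypothetical example: wraps another hook without changing its behavior."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hook = ExampleFeaturesHook()

With this sketch, ExampleWrapperHook(out_prefix="wrapped_") would expose ["wrapped_example_features"] as its out_keys, since the wrapper's prefix is applied on top of the wrapped hook's keys.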