Skip to content

utils

ApplyFnHook

Bases: BaseHook

Applies a function to specific values of the context.

Source code in `pytorch_adapt/hooks/utils.py` (lines 226–263)
class ApplyFnHook(BaseHook):
    """
    Applies a function to specific values of the context.
    """

    def __init__(
        self, fn: Callable, apply_to: List[str], is_loss: bool = False, **kwargs
    ):
        """
        Arguments:
            fn: The function that will be applied to the inputs.
            apply_to: fn will be applied to ```inputs[k]``` for k in apply_to
            is_loss: If False, then the returned loss dictionary will be empty.
                Otherwise, the returned output dictionary will be empty.
        """
        super().__init__(**kwargs)
        self.fn = fn
        self.apply_to = apply_to
        self.is_loss = is_loss

    def call(self, inputs, losses):
        """"""
        # Values extracted from the context, in the same order as apply_to.
        x = c_f.extract(inputs, self.apply_to)
        results = {k: self.fn(v) for k, v in zip(self.apply_to, x)}
        # When is_loss is True the results are losses: they must go in the
        # second (loss) dictionary, matching _loss_keys/_out_keys below.
        if self.is_loss:
            return {}, results
        return results, {}

    def _loss_keys(self):
        """"""
        return self.apply_to if self.is_loss else []

    def _out_keys(self):
        """"""
        return [] if self.is_loss else self.apply_to

    def extra_repr(self):
        return c_f.extra_repr(self, ["apply_to"])

__init__(fn, apply_to, is_loss=False, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fn` | `Callable` | The function that will be applied to the inputs. | *required* |
| `apply_to` | `List[str]` | `fn` will be applied to `inputs[k]` for `k` in `apply_to`. | *required* |
| `is_loss` | `bool` | If `False`, then the returned loss dictionary will be empty. Otherwise, the returned output dictionary will be empty. | `False` |

Source code in `pytorch_adapt/hooks/utils.py` (lines 231–244)
def __init__(
    self, fn: Callable, apply_to: List[str], is_loss: bool = False, **kwargs
):
    """
    Arguments:
        fn: Callable applied to each selected context value.
        apply_to: The context keys whose values are passed through ```fn```,
            i.e. ```fn(inputs[k])``` is computed for every k in apply_to.
        is_loss: When False, the results are outputs and the loss dictionary
            is empty; when True, the output dictionary is empty instead.
    """
    super().__init__(**kwargs)
    self.fn, self.apply_to, self.is_loss = fn, apply_to, is_loss

AssertHook

Bases: BaseWrapperHook

Asserts that the output keys of a hook match a specified regex string.

Source code in `pytorch_adapt/hooks/utils.py` (lines 298–331)
class AssertHook(BaseWrapperHook):
    """
    Wraps a hook and verifies that every key of its output dictionary
    matches a regular expression.
    """

    def __init__(self, hook: BaseHook, allowed: str, **kwargs):
        """
        Arguments:
            hook: The hook whose outputs are checked.
            allowed: A regex string. Every key in the output dictionary
                of ```hook``` must match it.
        """
        super().__init__(**kwargs)
        self.hook = hook
        if isinstance(allowed, str):
            self.allowed = allowed
        else:
            raise TypeError("allowed must be a str")

    def call(self, inputs, losses):
        """"""
        new_outputs, new_losses = self.hook(inputs, losses)
        self.assert_fn(new_outputs)
        return new_outputs, new_losses

    def assert_fn(self, outputs):
        # Keys removed by the regex filter are the offending ones.
        matching = c_f.filter(outputs, self.allowed)
        if len(matching) == len(outputs):
            return
        hook_name = c_f.cls_name(self.hook)
        wrapper_name = c_f.cls_name(self)
        raise ValueError(
            f"{hook_name} is producing outputs that don't match the allowed regex in {wrapper_name}\n"
            f"output keys = {outputs.keys()}\n"
            f"regex filter = {self.allowed}"
        )

    def extra_repr(self):
        return c_f.extra_repr(self, ["allowed"])

__init__(hook, allowed, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hook` | `BaseHook` | The wrapped hook. | *required* |
| `allowed` | `str` | The output dictionary of `hook` must have keys that match the `allowed` regex. | *required* |

Source code in `pytorch_adapt/hooks/utils.py` (lines 303–314)
def __init__(self, hook: BaseHook, allowed: str, **kwargs):
    """
    Arguments:
        hook: The hook whose outputs are checked.
        allowed: A regex string. Every key in the output dictionary
            of ```hook``` must match it.
    """
    super().__init__(**kwargs)
    self.hook = hook
    if isinstance(allowed, str):
        self.allowed = allowed
    else:
        raise TypeError("allowed must be a str")

ChainHook

Bases: BaseHook

Calls multiple hooks sequentially. The Nth hook receives the context accumulated through hooks 0 to N-1.

Source code in `pytorch_adapt/hooks/utils.py` (lines 57–156)
class ChainHook(BaseHook):
    """
    Runs several hooks one after another, threading the context through them.
    Hook N sees the original inputs/losses plus everything produced by
    hooks 0 to N-1.
    """

    def __init__(
        self,
        *hooks: BaseHook,
        conditions: List[BaseConditionHook] = None,
        alts: List[BaseHook] = None,
        overwrite: Union[bool, List[int]] = False,
        **kwargs,
    ):
        """
        Arguments:
            *hooks: The hooks to run, in order.
            conditions: Optional condition hooks, one per hook. hooks[i] runs
                when conditions[i] is truthy; otherwise alts[i] runs.
                Defaults to an always-true condition for every hook.
            alts: Optional fallback hooks, one per hook, used when the matching
                condition is falsy. Defaults to a ZeroLossHook with the same
                loss/out keys as the corresponding hook.
            overwrite: True lets every hook overwrite existing context keys.
                A list of integers grants that permission only to the hooks
                at those indices. False forbids overwriting entirely.
        """
        super().__init__(**kwargs)
        self.hooks = hooks
        default_conditions = [TrueHook() for _ in hooks]
        self.conditions = c_f.default(conditions, default_conditions)
        default_alts = [ZeroLossHook(h.loss_keys, h.out_keys) for h in self.hooks]
        self.alts = c_f.default(alts, default_alts)
        self.check_alt_keys_match_hook_keys()
        if not isinstance(overwrite, (list, bool)):
            raise TypeError("overwrite must be a list or bool")
        self.overwrite = overwrite
        self.in_keys = self.hooks[0].in_keys

    def call(self, inputs, losses):
        """"""
        merged_outputs, merged_losses = {}, {}
        ctx_inputs, ctx_losses = inputs, losses
        last_outputs, last_losses = {}, {}
        for idx, hook in enumerate(self.hooks):
            # Check the previous step's results against the context *before*
            # merging them in, so collisions are detected.
            self.check_overwrite(idx, ctx_inputs, last_outputs, self.overwrite)
            self.check_overwrite(idx, ctx_losses, last_losses, False)
            ctx_inputs = {**ctx_inputs, **last_outputs}
            ctx_losses = {**ctx_losses, **last_losses}
            if self.conditions[idx](ctx_inputs, ctx_losses):
                runner = hook
            else:
                runner = self.alts[idx]
            last_outputs, last_losses = runner(ctx_inputs, ctx_losses)
            merged_losses.update(last_losses)
            merged_outputs.update(last_outputs)
        return merged_outputs, merged_losses

    def check_overlap(self, x, y, names):
        is_overlap, overlap = c_f.dicts_are_overlapping(x, y, return_overlap=True)
        if not is_overlap:
            return
        raise KeyError(
            f"overwrite is false, but {names[0]} and {names[1]} have overlapping keys: {overlap}"
        )

    def check_overwrite(self, i, kwargs, prev_outputs, overwrite):
        # Overlap is tolerated only when overwrite is truthy and is either
        # True itself or a list that includes this hook's index.
        exempt = bool(overwrite) and not (
            isinstance(overwrite, list) and i not in overwrite
        )
        if not exempt:
            self.check_overlap(kwargs, prev_outputs, ["kwargs", "prev_outputs"])

    def _loss_keys(self):
        """"""
        return c_f.join_lists([hook.loss_keys for hook in self.hooks])

    def _out_keys(self):
        """"""
        return c_f.join_lists([hook.out_keys for hook in self.hooks])

    @property
    def last_hook_out_keys(self):
        return self.hooks[-1].out_keys

    def check_alt_keys_match_hook_keys(self):
        for i, hook in enumerate(self.hooks):
            alt = self.alts[i]
            if sorted(hook.loss_keys) == sorted(alt.loss_keys) and sorted(
                hook.out_keys
            ) == sorted(alt.out_keys):
                continue
            raise ValueError(
                "alt loss/out keys must be equal to hook loss/out keys"
            )

    def children_repr(self):
        x = super().children_repr()
        x["hooks"] = self.hooks
        # Conditions/alts are only interesting when some are non-default.
        if not all(isinstance(c, TrueHook) for c in self.conditions):
            x.update({"conditions": self.conditions, "alts": self.alts})
        return x

__init__(*hooks, conditions=None, alts=None, overwrite=False, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*hooks` | `BaseHook` | A sequence of hooks that will be called sequentially. | `()` |
| `conditions` | `List[BaseConditionHook]` | An optional list of condition hooks. If `conditions[i]` returns `False`, then `alts[i]` is called. Otherwise `hooks[i]` is called. | `None` |
| `alts` | `List[BaseHook]` | An optional list of hooks that will be executed when the corresponding condition hook returns `False`. | `None` |
| `overwrite` | `Union[bool, List[int]]` | If `True`, then hooks will be allowed to overwrite keys in the context. If a list of integers, then only the hooks at the specified indices will be allowed to overwrite keys in the context. | `False` |

Source code in `pytorch_adapt/hooks/utils.py` (lines 63–96)
def __init__(
    self,
    *hooks: BaseHook,
    conditions: List[BaseConditionHook] = None,
    alts: List[BaseHook] = None,
    overwrite: Union[bool, List[int]] = False,
    **kwargs,
):
    """
    Arguments:
        *hooks: The hooks to run, in order.
        conditions: Optional condition hooks, one per hook. hooks[i] runs
            when conditions[i] is truthy; otherwise alts[i] runs.
            Defaults to an always-true condition for every hook.
        alts: Optional fallback hooks, one per hook, used when the matching
            condition is falsy. Defaults to a ZeroLossHook with the same
            loss/out keys as the corresponding hook.
        overwrite: True lets every hook overwrite existing context keys.
            A list of integers grants that permission only to the hooks
            at those indices. False forbids overwriting entirely.
    """
    super().__init__(**kwargs)
    self.hooks = hooks
    default_conditions = [TrueHook() for _ in hooks]
    self.conditions = c_f.default(conditions, default_conditions)
    default_alts = [ZeroLossHook(h.loss_keys, h.out_keys) for h in self.hooks]
    self.alts = c_f.default(alts, default_alts)
    self.check_alt_keys_match_hook_keys()
    if not isinstance(overwrite, (list, bool)):
        raise TypeError("overwrite must be a list or bool")
    self.overwrite = overwrite
    self.in_keys = self.hooks[0].in_keys

EmptyHook

Bases: BaseHook

Returns two empty dictionaries.

Source code in `pytorch_adapt/hooks/utils.py` (lines 7–20)
class EmptyHook(BaseHook):
    """A no-op hook: it produces neither outputs nor losses."""

    def call(self, inputs, losses):
        """"""
        no_outputs, no_losses = {}, {}
        return no_outputs, no_losses

    def _loss_keys(self):
        """"""
        return list()

    def _out_keys(self):
        """"""
        return list()

FalseHook

Bases: BaseConditionHook

Returns False

Source code in `pytorch_adapt/hooks/utils.py` (lines 274–279)
class FalseHook(BaseConditionHook):
    """Condition hook whose answer is always ```False```."""

    def call(self, inputs, losses):
        """"""
        # The context is ignored; the result is constant.
        result = False
        return result

MultiplierHook

Bases: BaseWrapperHook

Multiplies every loss by a scalar.

Source code in `pytorch_adapt/hooks/utils.py` (lines 334–356)
class MultiplierHook(BaseWrapperHook):
    """
    Wraps a hook and scales each of its losses by a constant factor.
    Outputs are passed through untouched.
    """

    def __init__(self, hook: BaseHook, m: float, **kwargs):
        """
        Arguments:
            hook: The hook whose losses get scaled by ```m```.
            m: The scaling factor.
        """
        super().__init__(**kwargs)
        self.hook = hook
        self.m = m

    def call(self, inputs, losses):
        """"""
        outputs, raw_losses = self.hook(inputs, losses)
        scaled = {}
        for name, value in raw_losses.items():
            scaled[name] = value * self.m
        return outputs, scaled

    def extra_repr(self):
        return c_f.extra_repr(self, ["m"])

__init__(hook, m, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hook` | `BaseHook` | The losses of this hook will be multiplied by `m`. | *required* |
| `m` | `float` | The scalar. | *required* |

Source code in `pytorch_adapt/hooks/utils.py` (lines 339–347)
def __init__(self, hook: BaseHook, m: float, **kwargs):
    """
    Arguments:
        hook: The hook whose losses get scaled by ```m```.
        m: The scaling factor.
    """
    super().__init__(**kwargs)
    self.hook, self.m = hook, m

NotHook

Bases: BaseConditionHook

Returns the boolean negation of the wrapped hook.

Source code in `pytorch_adapt/hooks/utils.py` (lines 282–295)
class NotHook(BaseConditionHook):
    """Condition hook that inverts the truth value of another condition hook."""

    def __init__(self, hook: BaseConditionHook, **kwargs):
        """
        Arguments:
            hook: The condition hook whose result is inverted.
        """
        super().__init__(**kwargs)
        self.hook = hook

    def call(self, inputs, losses):
        """"""
        verdict = self.hook(inputs, losses)
        return False if verdict else True

__init__(hook, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hook` | `BaseConditionHook` | The condition hook that will be negated. | *required* |

Source code in `pytorch_adapt/hooks/utils.py` (lines 285–291)
def __init__(self, hook: BaseConditionHook, **kwargs):
    """
    Arguments:
        hook: The condition hook whose result is inverted.
    """
    super().__init__(**kwargs)
    wrapped = hook
    self.hook = wrapped

OnlyNewOutputsHook

Bases: BaseWrapperHook

Returns only outputs whose keys are not already present in the input context. Use this when `self.hook` needs to change the value of an existing key internally, but that change should not propagate to the outside.

Source code in `pytorch_adapt/hooks/utils.py` (lines 202–223)
class OnlyNewOutputsHook(BaseWrapperHook):
    """
    Returns only outputs that are not present in the input context.
    You should use this if you want to change the value of
    a key passed to self.hook, but not propagate that change
    to the outside.
    """

    def __init__(self, hook: BaseHook, **kwargs):
        """
        Arguments:
            hook: The hook inside which changes to the context will be allowed.
        """
        super().__init__(**kwargs)
        self.hook = hook

    def call(self, inputs, losses):
        """"""
        outputs, out_losses = self.hook(inputs, losses)
        # Drop every key the caller already had. Filtering outputs.items()
        # (rather than iterating a set difference) preserves the wrapped
        # hook's key order, so the result does not depend on string-hash
        # randomization between runs.
        outputs = {k: v for k, v in outputs.items() if k not in inputs}
        c_f.assert_dicts_are_disjoint(inputs, outputs)
        return outputs, out_losses

__init__(hook, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hook` | `BaseHook` | The hook inside which changes to the context will be allowed. | *required* |

Source code in `pytorch_adapt/hooks/utils.py` (lines 210–216)
def __init__(self, hook: BaseHook, **kwargs):
    """
    Arguments:
        hook: The wrapped hook; it may modify existing context keys, but
            only keys it newly creates are exposed to the caller.
    """
    super().__init__(**kwargs)
    wrapped = hook
    self.hook = wrapped

ParallelHook

Bases: BaseHook

Calls multiple hooks while keeping contexts separate. The Nth hook receives the same context as hooks 0 to N-1. All the output contexts are merged at the end.

Source code in `pytorch_adapt/hooks/utils.py` (lines 161–199)
class ParallelHook(BaseHook):
    """
    Runs several hooks against the same, unmodified context.
    No hook sees what another hook produced; their outputs and losses
    are merged at the end, with later hooks winning on key collisions.
    """

    def __init__(self, *hooks: BaseHook, **kwargs):
        """
        Arguments:
            *hooks: The hooks to run. They are invoked one at a time, in
                order, and each receives the same initial context.
        """
        super().__init__(**kwargs)
        self.hooks = hooks
        # The combined hook needs every key that any sub-hook needs.
        self.in_keys = c_f.join_lists([hook.in_keys for hook in self.hooks])

    def call(self, inputs, losses):
        """"""
        merged_outputs = {}
        merged_losses = {}
        for hook in self.hooks:
            result = hook(inputs, losses)
            merged_outputs.update(result[0])
            merged_losses.update(result[1])
        return merged_outputs, merged_losses

    def _loss_keys(self):
        """"""
        return c_f.join_lists([hook.loss_keys for hook in self.hooks])

    def _out_keys(self):
        """"""
        return c_f.join_lists([hook.out_keys for hook in self.hooks])

    def children_repr(self):
        children = super().children_repr()
        children["hooks"] = self.hooks
        return children

__init__(*hooks, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*hooks` | `BaseHook` | A sequence of hooks that will be called sequentially, with each hook receiving the same initial context. | `()` |

Source code in `pytorch_adapt/hooks/utils.py` (lines 168–176)
def __init__(self, *hooks: BaseHook, **kwargs):
    """
    Arguments:
        *hooks: The hooks to run. They are invoked one at a time, in
            order, and each receives the same initial context.
    """
    super().__init__(**kwargs)
    self.hooks = hooks
    # The combined hook needs every key that any sub-hook needs.
    self.in_keys = c_f.join_lists([hook.in_keys for hook in self.hooks])

RepeatHook

Bases: BaseHook

Executes the wrapped hook n times.

Source code in `pytorch_adapt/hooks/utils.py` (lines 359–406)
class RepeatHook(BaseHook):
    """
    Executes the wrapped hook ```n``` times.
    """

    def __init__(self, hook: BaseHook, n: int, keep_only_last: bool = False, **kwargs):
        """
        Arguments:
            hook: The hook that will be executed ```n``` times
            n: The number of times the hook will be executed.
            keep_only_last: If ```False```, the (outputs, losses) from each execution
                will be accumulated, and the keys will have the iteration number appended.
                If ```True```, then only the (outputs, losses) of the final execution will
                be kept.
        """
        super().__init__(**kwargs)
        self.hook = hook
        self.n = n
        self.keep_only_last = keep_only_last

    def call(self, inputs, losses):
        """"""
        # The accumulators use their own names so that every execution of
        # the wrapped hook receives the caller's ``losses`` context, not the
        # partially built result dictionary.
        all_outputs, all_losses = {}, {}
        for i in range(self.n):
            x = self.hook(inputs, losses)
            if self.keep_only_last:
                # Intermediate results are discarded; the final execution's
                # dictionaries are returned with their original keys.
                if i == self.n - 1:
                    all_outputs, all_losses = x
            else:
                # Suffix each key with the iteration index, e.g. "loss" -> "loss0".
                all_outputs.update({f"{k}{i}": v for k, v in x[0].items()})
                all_losses.update({f"{k}{i}": v for k, v in x[1].items()})
        return all_outputs, all_losses

    def _loss_keys(self):
        """"""
        if self.keep_only_last:
            return self.hook.loss_keys
        else:
            return [f"{k}{i}" for k in self.hook.loss_keys for i in range(self.n)]

    def _out_keys(self):
        """"""
        if self.keep_only_last:
            return self.hook.out_keys
        else:
            return [f"{k}{i}" for k in self.hook.out_keys for i in range(self.n)]

    def extra_repr(self):
        return c_f.extra_repr(self, ["n", "keep_only_last"])

__init__(hook, n, keep_only_last=False, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `hook` | `BaseHook` | The hook that will be executed `n` times. | *required* |
| `n` | `int` | The number of times the hook will be executed. | *required* |
| `keep_only_last` | `bool` | If `False`, the (outputs, losses) from each execution will be accumulated, and the keys will have the iteration number appended. If `True`, then only the (outputs, losses) of the final execution will be kept. | `False` |

Source code in `pytorch_adapt/hooks/utils.py` (lines 364–377)
def __init__(self, hook: BaseHook, n: int, keep_only_last: bool = False, **kwargs):
    """
    Arguments:
        hook: The hook to execute repeatedly.
        n: How many times ```hook``` is executed.
        keep_only_last: When ```False```, results from every execution are
            kept, with the iteration index appended to each key. When
            ```True```, only the final execution's (outputs, losses) are kept.
    """
    super().__init__(**kwargs)
    self.hook, self.n, self.keep_only_last = hook, n, keep_only_last

TrueHook

Bases: BaseConditionHook

Returns True

Source code in `pytorch_adapt/hooks/utils.py` (lines 266–271)
class TrueHook(BaseConditionHook):
    """Condition hook whose answer is always ```True```."""

    def call(self, inputs, losses):
        """"""
        # The context is ignored; the result is constant.
        result = True
        return result

ZeroLossHook

Bases: BaseHook

Returns only 0 losses and None outputs.

Source code in `pytorch_adapt/hooks/utils.py` (lines 23–54)
class ZeroLossHook(BaseHook):
    """
    Returns only 0 losses and ```None``` outputs.
    """

    def __init__(self, loss_names: List[str], out_names: List[str], **kwargs):
        """
        Arguments:
            loss_names: The keys of the loss dictionary
                which will have ```tensor(0.)``` as its values.
            out_names: The keys of the output dictionary
                which will have ```None``` as its values.
        """
        super().__init__(**kwargs)
        self.loss_names = loss_names
        self.out_names = out_names

    def call(self, inputs, losses):
        """"""
        # Only emit None for output names the context does not already hold,
        # so existing values are not clobbered. Iterating the list (instead of
        # a set difference) keeps a deterministic key order; duplicates still
        # collapse because dict keys are unique.
        outputs = {k: None for k in self.out_names if k not in inputs}
        return (
            outputs,
            {k: c_f.zero_loss() for k in self.loss_names},
        )

    def _loss_keys(self):
        """"""
        return self.loss_names

    def _out_keys(self):
        """"""
        # Reports every configured name, even ones call() skipped because
        # they were already present in the inputs.
        return self.out_names

__init__(loss_names, out_names, **kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_names` | `List[str]` | The keys of the loss dictionary, which will have `tensor(0.)` as their values. | *required* |
| `out_names` | `List[str]` | The keys of the output dictionary, which will have `None` as their values. | *required* |

Source code in `pytorch_adapt/hooks/utils.py` (lines 28–38)
def __init__(self, loss_names: List[str], out_names: List[str], **kwargs):
    """
    Arguments:
        loss_names: Loss-dictionary keys that will each be given
            ```tensor(0.)``` as a value.
        out_names: Output-dictionary keys that will each be given
            ```None``` as a value.
    """
    super().__init__(**kwargs)
    self.loss_names, self.out_names = loss_names, out_names