Skip to content

features

BaseFeaturesHook

Bases: BaseHook

This hook:

  1. Checks to see if specific tensors are in the context
  2. Exits if the tensors are already in the context
  3. Otherwise computes those tensors using the appropriate inputs and models, and adds them to the context.
Source code in pytorch_adapt/hooks/features.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
class BaseFeaturesHook(BaseHook):
    """
    This hook:

    1. Checks to see if specific tensors are in the context
    2. Exits if the tensors are already in the context
    3. Otherwise computes those tensors using the appropriate
        inputs and models, and adds them to the context.
    """

    def __init__(
        self,
        model_name: str,
        in_suffixes: List[str] = None,
        out_suffixes: List[str] = None,
        domains: List[str] = None,
        detach: bool = False,
        **kwargs,
    ):
        """
        Arguments:
            model_name: The name of the model that will
                be used to compute any missing tensors.

            in_suffixes: The suffixes of the names of the inputs
                to the model. For example if:

                - ```domains = ["src", "target"]```
                - ```in_suffixes = ["_imgs_features"]```

                then the model will be given

                - ```["src_imgs_features", "target_imgs_features"]```.

            out_suffixes: The suffixes of the names of the outputs
                of the model. Output suffixes are appended to the input name.
                For example, if

                - ```domains = ["src", "target"]```
                - ```in_suffixes = ["_imgs_features"]```
                - ```out_suffixes = ["_logits"]```

                then the output keys will be

                - ```["src_imgs_features_logits", "target_imgs_features_logits"]```

            domains: The names of the domains to use. If ```None```,
                this defaults to ```["src", "target"]```.

            detach: If ```True```, then the output will be detached
                from the autograd graph. Any output that is detached
                will have ```"_detached"``` appended to its name in the
                context.
        """

        super().__init__(**kwargs)
        self.model_name = model_name
        self.domains = c_f.default(domains, ["src", "target"])
        # detach may be a bool (applied to every domain) or a per-domain dict
        self.init_detach_mode(detach)
        self.init_suffixes(in_suffixes, out_suffixes)

    def call(self, inputs, losses):
        """"""
        outputs = {}
        for domain in self.domains:
            self.logger(f"Getting {domain}")
            # Detach mode is resolved per domain and validated against
            # the global grad-enabled state.
            detach = self.check_grad_mode(domain)
            func = self.mode_detached if detach else self.mode_with_grad
            # Keep only this domain's input keys (those starting with e.g. "src").
            in_keys = c_f.filter(self.in_keys, f"^{domain}")
            func(inputs, outputs, domain, in_keys)

        # Sanity check: key names and requires_grad of the outputs must agree.
        self.check_outputs_requires_grad(outputs)
        return outputs, {}

    def check_grad_mode(self, domain):
        # Returns whether this domain's outputs should be detached.
        # If autograd is globally disabled, with-grad outputs are impossible.
        detach = self.detach[domain]
        if not torch.is_grad_enabled():
            if not detach:
                raise ValueError(
                    f"detach[{domain}] == {detach} but grad is not enabled"
                )
        return detach

    def check_outputs_requires_grad(self, outputs):
        # A key ending in "detached" must hold a tensor that does not require
        # grad, and a key without "detached" must hold one that does.
        for k, v in outputs.items():
            if k.endswith("detached") and c_f.requires_grad(v, does=True):
                raise TypeError(f"{k} ends with 'detached' but tensor requires grad")
            if not k.endswith("detached") and c_f.requires_grad(v, does=False):
                raise TypeError(
                    f"{k} doesn't end in 'detached' but tensor doesn't require grad"
                )

    def mode_with_grad(self, inputs, outputs, domain, in_keys):
        # Compute (or reuse) this domain's outputs with autograd enabled.
        output_keys = c_f.filter(self._out_keys(), f"^{domain}")
        output_vals = self.get_kwargs(inputs, output_keys)
        self.add_if_new(
            outputs, output_keys, output_vals, inputs, self.model_name, in_keys, domain
        )
        return output_keys, output_vals

    def mode_detached(self, inputs, outputs, domain, in_keys):
        # First try to satisfy the detached outputs by detaching tensors that
        # already exist in the context, then compute whatever is still missing
        # under torch.no_grad().
        curr_out_keys = c_f.filter(self._out_keys(), f"^{domain}")
        self.try_existing_detachable(inputs, outputs, curr_out_keys)
        remaining_out_keys = [
            k for k in curr_out_keys if k not in set().union(inputs, outputs)
        ]
        if len(remaining_out_keys) > 0:
            output_vals = self.get_kwargs(inputs, remaining_out_keys)
            with torch.no_grad():
                self.add_if_new(
                    outputs,
                    remaining_out_keys,
                    output_vals,
                    inputs,
                    self.model_name,
                    in_keys,
                    domain,
                )

    def add_if_new(
        self, outputs, full_key, output_vals, inputs, model_name, in_keys, domain
    ):
        # Delegates to c_f.add_if_new. The domain argument is accepted but
        # unused in this base implementation.
        c_f.add_if_new(
            outputs,
            full_key,
            output_vals,
            inputs,
            model_name,
            in_keys,
            logger=self.logger,
        )

    def create_keys(self, domain, suffix, starting_keys=None, detach=False):
        if starting_keys is None:
            # e.g. domain="src", suffix=["_imgs"] -> ["src_imgs"]
            full_keys = [f"{domain}{x}" for x in suffix]
        else:
            # Multiple starting keys are merged into a single "_AND_" key,
            # which is then repeated once per suffix before zipping.
            if len(starting_keys) > 1:
                starting_keys = self.join_keys(starting_keys)
            if len(suffix) > 1:
                starting_keys = starting_keys * len(suffix)
            full_keys = [f"{k}{x}" for k, x in zip(starting_keys, suffix)]
        if detach:
            full_keys = self.add_detached_string(full_keys)
        return full_keys

    def get_kwargs(self, inputs, keys):
        # Returns None for keys not yet present in the context.
        return [inputs.get(k) for k in keys]

    def try_existing_detachable(self, inputs, outputs, curr_out_keys):
        # For each missing output key, look for an already-computed tensor
        # (its name may contain "_detached" anywhere) that can simply be
        # detached. inputs is searched first, then outputs.
        for k in curr_out_keys:
            if k in inputs or k in outputs:
                continue
            curr_regex = self.detachable_regex[k]
            success = self.try_existing_detachable_in_dict(
                curr_regex, inputs, outputs, k
            )
            if not success:
                self.try_existing_detachable_in_dict(curr_regex, outputs, outputs, k)

    def try_existing_detachable_in_dict(self, regex, in_dict, outputs, new_k):
        # Copy the first matching non-None value into outputs, detached.
        for k, v in in_dict.items():
            if regex.search(k) and v is not None:
                outputs[new_k] = v.detach()
                return True
        return False

    def add_detached_string(self, keys):
        # delete existing detached string, then append to the very end
        # for example, if computing detached logits for: src_imgs_features_detached
        # 1. src_imgs_features_detached_logits --> src_imgs_features_logits
        # 2. src_imgs_features_logits --> src_imgs_features_logits_detached
        keys = [k.replace("_detached", "") for k in keys]
        return [f"{k}_detached" for k in keys]

    def join_keys(self, keys):
        # ["a", "b"] -> ["a_AND_b"]
        return ["_AND_".join(keys)]

    def init_detach_mode(self, detach):
        # Normalize detach into a {domain: bool} dict.
        if isinstance(detach, dict):
            if any(not isinstance(v, bool) for v in detach.values()):
                raise TypeError("if detach is a dict, values must be bools")
            self.detach = detach
        elif isinstance(detach, bool):
            self.detach = {k: detach for k in self.domains}
        else:
            raise TypeError("detach must be a bool or a dict of bools")

    def init_suffixes(self, in_suffixes, out_suffixes):
        self.in_suffixes = in_suffixes
        self.out_suffixes = out_suffixes
        in_keys = []
        for domain in self.domains:
            in_keys.extend(self.create_keys(domain, in_suffixes))
        self.set_in_keys(in_keys)

    def set_in_keys(self, in_keys):
        super().set_in_keys(in_keys)
        # Derive each domain's output keys from its input keys.
        self.all_out_keys = []
        for domain in self.domains:
            curr_in_keys = c_f.filter(self.in_keys, f"^{domain}")
            curr_out_keys = self.create_keys(
                domain, self.out_suffixes, curr_in_keys, detach=self.detach[domain]
            )
            self.all_out_keys.extend(curr_out_keys)

        # strings with '_detached' optional and anywhere
        self.detachable_regex = {
            k: re.compile(
                f"^{k.replace('_detached', '').replace('_', '(_detached)?_')}$"
            )
            for k in self.all_out_keys
        }

    def _loss_keys(self):
        """"""
        return []

    def _out_keys(self):
        """"""
        return self.all_out_keys

    def extra_repr(self):
        return c_f.extra_repr(self, ["model_name", "domains", "detach"])

__init__(model_name, in_suffixes=None, out_suffixes=None, domains=None, detach=False, **kwargs)

Parameters:

Name Type Description Default
model_name str

The name of the model that will be used to compute any missing tensors.

required
in_suffixes List[str]

The suffixes of the names of the inputs to the model. For example if:

  • domains = ["src", "target"]
  • in_suffixes = ["_imgs_features"]

then the model will be given

  • ["src_imgs_features", "target_imgs_features"].
None
out_suffixes List[str]

The suffixes of the names of the outputs of the model. Output suffixes are appended to the input name. For example, if

  • domains = ["src", "target"]
  • in_suffixes = ["_imgs_features"]
  • out_suffixes = ["_logits"]

then the output keys will be

  • ["src_imgs_features_logits", "target_imgs_features_logits"]
None
domains List[str]

The names of the domains to use. If None, this defaults to ["src", "target"].

None
detach bool

If True, then the output will be detached from the autograd graph. Any output that is detached will have "_detached" appended to its name in the context.

False
Source code in pytorch_adapt/hooks/features.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def __init__(
    self,
    model_name: str,
    in_suffixes: List[str] = None,
    out_suffixes: List[str] = None,
    domains: List[str] = None,
    detach: bool = False,
    **kwargs,
):
    """
    Arguments:
        model_name: The name of the model that will
            be used to compute any missing tensors.

        in_suffixes: The suffixes of the names of the inputs
            to the model. For example if:

            - ```domains = ["src", "target"]```
            - ```in_suffixes = ["_imgs_features"]```

            then the model will be given

            - ```["src_imgs_features", "target_imgs_features"]```.

        out_suffixes: The suffixes of the names of the outputs
            of the model. Output suffixes are appended to the input name.
            For example, if

            - ```domains = ["src", "target"]```
            - ```in_suffixes = ["_imgs_features"]```
            - ```out_suffixes = ["_logits"]```

            then the output keys will be

            - ```["src_imgs_features_logits", "target_imgs_features_logits"]```

        domains: The names of the domains to use. If ```None```,
            this defaults to ```["src", "target"]```.

        detach: If ```True```, then the output will be detached
            from the autograd graph. Any output that is detached
            will have ```"_detached"``` appended to its name in the
            context.
    """

    super().__init__(**kwargs)
    self.model_name = model_name
    self.domains = c_f.default(domains, ["src", "target"])
    # detach may be a bool (applied to every domain) or a per-domain dict
    self.init_detach_mode(detach)
    self.init_suffixes(in_suffixes, out_suffixes)

CombinedFeaturesHook

Bases: BaseFeaturesHook

Default input/output context names:

  • Model: "feature_combiner"
  • Inputs:
    ["src_imgs_features",
    "src_imgs_features_logits",
    "target_imgs_features",
    "target_imgs_features_logits"]
    
  • Outputs:
    ["src_imgs_features_AND_src_imgs_features_logits_combined",
    "target_imgs_features_AND_target_imgs_features_logits_combined"]
    
Source code in pytorch_adapt/hooks/features.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
class CombinedFeaturesHook(BaseFeaturesHook):
    """
    Computes combined features using the ```"feature_combiner"``` model.

    Default input/output context names:

    - Model: ```"feature_combiner"```
    - Inputs:
        ```
        ["src_imgs_features",
        "src_imgs_features_logits",
        "target_imgs_features",
        "target_imgs_features_logits"]
        ```
    - Outputs:
        ```
        ["src_imgs_features_AND_src_imgs_features_logits_combined",
        "target_imgs_features_AND_target_imgs_features_logits_combined"]
        ```
    """

    def __init__(self, in_suffixes=None, out_suffixes=None, **kwargs):
        # The model name is fixed; only the suffixes are configurable.
        super().__init__(
            model_name="feature_combiner",
            in_suffixes=c_f.default(
                in_suffixes, ["_imgs_features", "_imgs_features_logits"]
            ),
            out_suffixes=c_f.default(out_suffixes, ["_combined"]),
            **kwargs,
        )

DLogitsHook

Bases: BaseFeaturesHook

Default input/output context names:

  • Model: "D"
  • Inputs:
    ["src_imgs_features",
    "target_imgs_features"]
    
  • Outputs:
    ["src_imgs_features_dlogits",
    "target_imgs_features_dlogits"]
    
Source code in pytorch_adapt/hooks/features.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
class DLogitsHook(BaseFeaturesHook):
    """
    Computes discriminator logits.

    Default input/output context names:

    - Model: ```"D"```
    - Inputs:
        ```
        ["src_imgs_features",
        "target_imgs_features"]
        ```
    - Outputs:
        ```
        ["src_imgs_features_dlogits",
        "target_imgs_features_dlogits"]
        ```
    """

    def __init__(self, model_name="D", in_suffixes=None, out_suffixes=None, **kwargs):
        # Standard discriminator suffixes are used when none are given.
        super().__init__(
            model_name=model_name,
            in_suffixes=c_f.default(in_suffixes, ["_imgs_features"]),
            out_suffixes=c_f.default(out_suffixes, ["_dlogits"]),
            **kwargs,
        )

FeaturesAndLogitsHook

Bases: FeaturesChainHook

Chains together FeaturesHook and LogitsHook.

Source code in pytorch_adapt/hooks/features.py
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
class FeaturesAndLogitsHook(FeaturesChainHook):
    """
    Chains together [```FeaturesHook```][pytorch_adapt.hooks.FeaturesHook]
    and [```LogitsHook```][pytorch_adapt.hooks.LogitsHook].
    """

    def __init__(
        self,
        domains: List[str] = None,
        detach_features: bool = False,
        detach_logits: bool = False,
        other_hooks: List[BaseHook] = None,
        **kwargs,
    ):
        """
        Arguments:
            domains: The domains used by both the features and logits hooks.
                If ```None```, it defaults to ```["src", "target"]```
            detach_features: If ```True```, returns features that are
                detached from the autograd graph.
            detach_logits: If ```True```, returns logits that are
                detached from the autograd graph.
            other_hooks: A list of hooks that will be called after
                the features and logits hooks.
        """
        # Features first, then logits, then any user-supplied hooks.
        hooks = [
            FeaturesHook(detach=detach_features, domains=domains),
            LogitsHook(detach=detach_logits, domains=domains),
        ]
        hooks.extend(c_f.default(other_hooks, []))
        super().__init__(*hooks, **kwargs)

__init__(domains=None, detach_features=False, detach_logits=False, other_hooks=None, **kwargs)

Parameters:

Name Type Description Default
domains List[str]

The domains used by both the features and logits hooks. If None, it defaults to ["src", "target"]

None
detach_features bool

If True, returns features that are detached from the autograd graph.

False
detach_logits bool

If True, returns logits that are detached from the autograd graph.

False
other_hooks List[BaseHook]

A list of hooks that will be called after the features and logits hooks.

None
Source code in pytorch_adapt/hooks/features.py
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def __init__(
    self,
    domains: List[str] = None,
    detach_features: bool = False,
    detach_logits: bool = False,
    other_hooks: List[BaseHook] = None,
    **kwargs,
):
    """
    Arguments:
        domains: The domains used by both the features and logits hooks.
            If ```None```, it defaults to ```["src", "target"]```
        detach_features: If ```True```, returns features that are
            detached from the autograd graph.
        detach_logits: If ```True```, returns logits that are
            detached from the autograd graph.
        other_hooks: A list of hooks that will be called after
            the features and logits hooks.
    """
    features_hook = FeaturesHook(detach=detach_features, domains=domains)
    logits_hook = LogitsHook(detach=detach_logits, domains=domains)
    # other_hooks defaults to an empty list
    other_hooks = c_f.default(other_hooks, [])
    super().__init__(features_hook, logits_hook, *other_hooks, **kwargs)

FeaturesChainHook

Bases: ChainHook

A special ChainHook for features hooks. It sets each sub-hook's in_keys using the previous sub-hook's out_keys.

Source code in pytorch_adapt/hooks/features.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
class FeaturesChainHook(ChainHook):
    """
    A special [```ChainHook```][pytorch_adapt.hooks.ChainHook]
    for features hooks. It sets each sub-hook's ```in_keys``` using
    the previous sub-hook's ```out_keys```.
    """

    def __init__(self, *hooks, **kwargs):
        # Wire consecutive hooks together: each hook consumes the
        # output keys of the hook immediately before it.
        for prev_hook, next_hook in zip(hooks, hooks[1:]):
            next_hook.set_in_keys(prev_hook.out_keys)
        super().__init__(*hooks, **kwargs)

FeaturesHook

Bases: BaseFeaturesHook

Default input/output context names:

  • Model: "G"
  • Inputs:
    ["src_imgs",
    "target_imgs"]
    
  • Outputs:
    ["src_imgs_features",
    "target_imgs_features"]
    
Source code in pytorch_adapt/hooks/features.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
class FeaturesHook(BaseFeaturesHook):
    """
    Computes features from the generator model.

    Default input/output context names:

    - Model: ```"G"```
    - Inputs:
        ```
        ["src_imgs",
        "target_imgs"]
        ```
    - Outputs:
        ```
        ["src_imgs_features",
        "target_imgs_features"]
        ```
    """

    def __init__(self, model_name="G", in_suffixes=None, out_suffixes=None, **kwargs):
        # Standard image/feature suffixes are used when none are given.
        super().__init__(
            model_name=model_name,
            in_suffixes=c_f.default(in_suffixes, ["_imgs"]),
            out_suffixes=c_f.default(out_suffixes, ["_features"]),
            **kwargs,
        )

FeaturesWithGradAndDetachedHook

Bases: BaseWrapperHook

Default input/output context names:

  • Model: "G"
  • Inputs:
    ["src_imgs",
    "target_imgs"]
    
  • Outputs:
    ["src_imgs_features",
    "target_imgs_features",
    "src_imgs_features_detached",
    "target_imgs_features_detached"]
    
Source code in pytorch_adapt/hooks/features.py
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class FeaturesWithGradAndDetachedHook(BaseWrapperHook):
    """
    Computes features both with grad and detached.

    Default input/output context names:

    - Model: ```"G"```
    - Inputs:
        ```
        ["src_imgs",
        "target_imgs"]
        ```
    - Outputs:
        ```
        ["src_imgs_features",
        "target_imgs_features",
        "src_imgs_features_detached",
        "target_imgs_features_detached"]
        ```
    """

    def __init__(
        self,
        model_name="G",
        in_suffixes=None,
        out_suffixes=None,
        domains=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Chain one with-grad FeaturesHook followed by one detached FeaturesHook,
        # so both versions of the features end up in the context.
        self.hook = ChainHook(
            *(
                FeaturesHook(
                    model_name=model_name,
                    in_suffixes=in_suffixes,
                    out_suffixes=out_suffixes,
                    domains=domains,
                    detach=detach,
                    **kwargs,
                )
                for detach in (False, True)
            )
        )

FrozenModelHook

Bases: BaseWrapperHook

Sets model to eval() mode, and does all computations with gradients turned off.

Source code in pytorch_adapt/hooks/features.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
class FrozenModelHook(BaseWrapperHook):
    """
    Sets model to ```eval()``` mode, and does all
    computations with gradients turned off.
    """

    def __init__(self, hook: BaseHook, model_name: str, **kwargs):
        """
        Arguments:
            hook: The wrapped hook which computes all losses and outputs.
            model_name: The name of the model that will be set to eval() mode.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self.hook = hook

    def call(self, inputs, losses):
        """"""
        # Freeze the named model, then run the wrapped hook without autograd.
        inputs[self.model_name].eval()
        with torch.no_grad():
            return self.hook(inputs, losses)

__init__(hook, model_name, **kwargs)

Parameters:

Name Type Description Default
hook BaseHook

The wrapped hook which computes all losses and outputs.

required
model_name str

The name of the model that will be set to eval() mode.

required
Source code in pytorch_adapt/hooks/features.py
472
473
474
475
476
477
478
479
480
481
def __init__(self, hook: BaseHook, model_name: str, **kwargs):
    """
    Arguments:
        hook: The wrapped hook which computes all losses and outputs.
        model_name: The name of the model that will be set to eval() mode.
    """

    super().__init__(**kwargs)
    # The wrapped hook is executed inside torch.no_grad() in call().
    self.hook = hook
    self.model_name = model_name

LogitsHook

Bases: BaseFeaturesHook

Default input/output context names:

  • Model: "C"
  • Inputs:
    ["src_imgs_features",
    "target_imgs_features"]
    
  • Outputs:
    ["src_imgs_features_logits",
    "target_imgs_features_logits"]
    
Source code in pytorch_adapt/hooks/features.py
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
class LogitsHook(BaseFeaturesHook):
    """
    Computes logits from the classifier model.

    Default input/output context names:

    - Model: ```"C"```
    - Inputs:
        ```
        ["src_imgs_features",
        "target_imgs_features"]
        ```
    - Outputs:
        ```
        ["src_imgs_features_logits",
        "target_imgs_features_logits"]
        ```
    """

    def __init__(self, model_name="C", in_suffixes=None, out_suffixes=None, **kwargs):
        # Standard feature/logit suffixes are used when none are given.
        super().__init__(
            model_name=model_name,
            in_suffixes=c_f.default(in_suffixes, ["_imgs_features"]),
            out_suffixes=c_f.default(out_suffixes, ["_logits"]),
            **kwargs,
        )