aligners
AlignerHook
Bases: BaseWrapperHook
Computes an alignment loss (e.g. MMD) based on features from two domains.
Source code in pytorch_adapt\hooks\aligners.py
__init__(loss_fn=None, hook=None, layer='features', **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | a function that computes a distance between two tensors. If … | None |
hook | BaseHook | the hook for computing features | None |
layer | str | the layer for which the loss is computed. Must be either … | 'features' |
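To make the `loss_fn` signature concrete, here is a minimal sketch of a callable that takes two tensors and returns a scalar tensor. The function name `mean_feature_distance` and the sample data are hypothetical, and this is not the library's own loss: it is the squared distance between batch means, which matches a biased MMD estimate with a linear kernel.

```python
import torch


def mean_feature_distance(src: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss_fn: squared Euclidean distance between batch means."""
    # Average over the batch dimension, leaving one vector per domain.
    delta = src.mean(dim=0) - target.mean(dim=0)
    # Reduce to a scalar so the result can be used directly as a loss.
    return torch.sum(delta ** 2)


# Made-up feature batches: 2 samples per domain, 2 dimensions each.
src_features = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target_features = torch.tensor([[0.0, 0.0], [2.0, 2.0]])

loss = mean_feature_distance(src_features, target_features)
```

Assuming the hook calls `loss_fn` with the source and target tensors of the chosen layer, as the type annotation suggests, a function of this shape could be supplied as `loss_fn`.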
AlignerPlusCHook
Bases: BaseWrapperHook
Computes an alignment loss plus a classification loss, and then optimizes the models.
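The sequence described here (alignment loss, classification loss, then an optimizer step) can be sketched in plain PyTorch. This is only an illustration of the idea, not the hook's implementation: the models `G` and `C`, the equal weighting of the two losses, the random data, and the `alignment_loss` function are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

G = nn.Linear(4, 3)  # hypothetical feature generator
C = nn.Linear(3, 2)  # hypothetical classifier
optimizer = torch.optim.SGD(list(G.parameters()) + list(C.parameters()), lr=0.1)

# Made-up batches: labeled source data and unlabeled target data.
src_imgs = torch.randn(8, 4)
src_labels = torch.randint(0, 2, (8,))
target_imgs = torch.randn(8, 4)


def alignment_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Stand-in alignment loss: squared distance between batch means."""
    return torch.sum((x.mean(dim=0) - y.mean(dim=0)) ** 2)


# Snapshot of G's parameters, to confirm below that the step changed them.
before = [p.detach().clone() for p in G.parameters()]

src_features = G(src_imgs)
target_features = G(target_imgs)

# Classification loss uses source labels only; alignment uses both domains.
c_loss = F.cross_entropy(C(src_features), src_labels)
a_loss = alignment_loss(src_features, target_features)
total_loss = c_loss + a_loss  # equal weighting is an assumption

optimizer.zero_grad()
total_loss.backward()
optimizer.step()
```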
FeaturesLogitsAlignerHook
Bases: BaseWrapperHook
This chains together an AlignerHook for "features" followed by an AlignerHook for "logits".
__init__(loss_fn=None, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | The loss used by both aligner hooks. | None |
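As a rough illustration of applying one loss to both layers, the sketch below evaluates the same function on features and then on logits. It uses plain tensors rather than hooks; the function `mean_distance`, the dictionary layout, and the numbers are hypothetical.

```python
import torch


def mean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical loss_fn: squared distance between batch means."""
    return torch.sum((x.mean(dim=0) - y.mean(dim=0)) ** 2)


# Made-up model outputs for one source batch and one target batch.
src = {
    "features": torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
    "logits": torch.tensor([[1.0, 0.0], [1.0, 2.0]]),
}
target = {
    "features": torch.tensor([[0.0, 0.0], [2.0, 2.0]]),
    "logits": torch.tensor([[0.0, 1.0], [0.0, 1.0]]),
}

# Apply the same loss to each layer in turn: features first, then logits.
losses = {
    layer: mean_distance(src[layer], target[layer])
    for layer in ("features", "logits")
}
```

Each layer gets its own scalar loss, so the two terms can be inspected or weighted separately.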
JointAlignerHook
Bases: BaseWrapperHook
Computes a joint alignment loss (e.g. Joint MMD) based on multiple features from two domains.
The default setting is to use the features and logits from the source and target domains.
__init__(loss_fn=None, hook=None, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | Callable[[List[torch.Tensor], List[torch.Tensor]], torch.Tensor] | a function that computes a distance between two lists of tensors. If … | None |
hook | BaseHook | the hook for computing features and logits | None |
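The list-based `loss_fn` signature differs from the single-tensor case, so a small sketch may help. The function below is a simplified joint discrepancy in the spirit of Joint MMD, using linear kernels multiplied across layers and a biased estimate. The name `joint_linear_mmd` and the sample tensors are hypothetical, and this is not the library's implementation.

```python
from typing import List

import torch


def joint_kernel(xs: List[torch.Tensor], ys: List[torch.Tensor]) -> torch.Tensor:
    """Elementwise product of the per-layer linear kernel matrices."""
    per_layer = [x @ y.T for x, y in zip(xs, ys)]
    return torch.stack(per_layer).prod(dim=0)


def joint_linear_mmd(src: List[torch.Tensor], target: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical loss_fn over lists of tensors, one tensor per layer."""
    k_ss = joint_kernel(src, src).mean()
    k_tt = joint_kernel(target, target).mean()
    k_st = joint_kernel(src, target).mean()
    return k_ss + k_tt - 2 * k_st


# Made-up [features, logits] lists for each domain, 2 samples per batch.
src = [
    torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
    torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
]
target = [
    torch.tensor([[1.0, 0.0], [1.0, 0.0]]),
    torch.tensor([[1.0, 0.0], [1.0, 0.0]]),
]

loss = joint_linear_mmd(src, target)
```

Multiplying the kernel matrices across layers is what makes the comparison joint: two samples only count as similar when they are similar in every layer at once.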