classification
CLossHook
Bases: BaseWrapperHook
Computes a classification loss on the specified tensors. The default setting is to compute the cross entropy loss of the source domain logits.
Source code in pytorch_adapt\hooks\classification.py
__init__(loss_fn=None, detach_features=False, f_hook=None, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_fn | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | The classification loss function. If None, cross entropy loss is used. | None |
detach_features | bool | Whether or not to detach the features from which the logits are computed. | False |
f_hook | BaseHook | The hook for computing logits. | None |
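As a rough, framework-free illustration of the default behavior, the sketch below computes the mean cross entropy of a batch of source-domain logits against integer labels, and then swaps in a custom loss_fn with the same (logits, labels) signature. It uses only the Python standard library, not PyTorch or pytorch_adapt. The names closs, cross_entropy and zero_one, the returned "c_loss" key, and the input keys "src_logits" and "src_labels" are hypothetical. detach_features and f_hook are not modeled, since they depend on autograd and on other hooks.

```python
import math
from typing import Callable, Dict, List, Optional

Logits = List[List[float]]  # (batch, classes)
Labels = List[int]          # one class index per row


def cross_entropy(logits: Logits, labels: Labels) -> float:
    """Mean cross entropy over the batch, using a max-shifted log-sum-exp for stability."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]  # equals -log(softmax(row)[label])
    return total / len(logits)


def closs(
    inputs: Dict[str, list],
    loss_fn: Optional[Callable[[Logits, Labels], float]] = None,
) -> Dict[str, float]:
    """Hypothetical stand-in for CLossHook (not the pytorch_adapt implementation)."""
    # Fall back to cross entropy when no loss function is given.
    fn = loss_fn if loss_fn is not None else cross_entropy
    return {"c_loss": fn(inputs["src_logits"], inputs["src_labels"])}


def zero_one(logits: Logits, labels: Labels) -> float:
    """A hypothetical custom loss: the fraction of misclassified rows."""
    wrong = sum(row.index(max(row)) != y for row, y in zip(logits, labels))
    return wrong / len(labels)


inputs = {"src_logits": [[2.0, 0.0], [0.0, 2.0]], "src_labels": [0, 1]}
print(closs(inputs))                    # cross entropy, about 0.127
print(closs(inputs, loss_fn=zero_one))  # {'c_loss': 0.0}
```

Passing loss_fn=None keeps the default, so swapping the loss requires no other change to the call.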
ClassifierHook
Bases: BaseWrapperHook
Computes the classification loss and uses it to optimize the models.
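The two steps (compute the loss, then update the model with it) can be sketched without PyTorch. The example below is a toy under stated assumptions, not pytorch_adapt code: a hypothetical linear classifier W is trained by plain gradient descent on softmax cross entropy, using the fact that the gradient of cross entropy with respect to logit k is p_k minus 1 when k is the label (and p_k otherwise). classifier_step, loss_and_grad and the learning rate 0.5 are illustrative choices.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]  # W[k][j]: weight from feature j to class k


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def forward(W: Matrix, x: List[float]) -> List[float]:
    """Logits of a linear classifier: logits[k] = sum_j W[k][j] * x[j]."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def loss_and_grad(W: Matrix, xs: List[List[float]], ys: List[int]) -> Tuple[float, Matrix]:
    """Mean cross entropy over the batch and its gradient with respect to W."""
    n = len(xs)
    grad = [[0.0] * len(W[0]) for _ in W]
    loss = 0.0
    for x, y in zip(xs, ys):
        p = softmax(forward(W, x))
        loss -= math.log(p[y]) / n
        for k in range(len(W)):
            d = p[k] - (1.0 if k == y else 0.0)  # d(loss)/d(logit_k)
            for j in range(len(x)):
                grad[k][j] += d * x[j] / n
    return loss, grad


def classifier_step(W: Matrix, xs: List[List[float]], ys: List[int], lr: float = 0.5) -> float:
    """Hypothetical stand-in for ClassifierHook: compute the loss, then update W in place."""
    loss, grad = loss_and_grad(W, xs, ys)
    for k, row in enumerate(grad):
        for j, g in enumerate(row):
            W[k][j] -= lr * g
    return loss


W = [[0.0, 0.0], [0.0, 0.0]]
xs, ys = [[1.0, 0.0], [0.0, 1.0]], [0, 1]
losses = [classifier_step(W, xs, ys) for _ in range(5)]
print(losses)  # starts at log(2) and decreases
```

Each call returns the loss measured before its own update, so the first value is log 2 (about 0.693) for the all-zero W.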
FinetunerHook
Bases: ClassifierHook
This is the same as ClassifierHook, but it freezes the generator model ("G").
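Freezing can be illustrated with a toy two-stage model in plain Python: a hypothetical scalar "generator" G that produces features, and a hypothetical "classifier" C that maps them to logits. The step function computes gradients for both models but skips any model named in frozen, so frozen={"G"} mimics the idea behind FinetunerHook while frozen=set() mimics ClassifierHook. This is an assumption-based sketch, not how pytorch_adapt implements freezing; the model name "C", the functions and the learning rate are illustrative.

```python
import math
from typing import Dict, List, Set, Tuple

Params = Dict[str, List[float]]  # model name -> list of weights


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def loss_and_grads(params: Params, xs: List[float], ys: List[int]) -> Tuple[float, Params]:
    """Cross entropy of a toy G -> C model and the gradients for both models."""
    g, c = params["G"][0], params["C"]
    grads: Params = {"G": [0.0], "C": [0.0] * len(c)}
    loss, n = 0.0, len(xs)
    for x, y in zip(xs, ys):
        f = g * x                          # G: a single scalar feature
        p = softmax([ck * f for ck in c])  # C: one weight per class
        loss -= math.log(p[y]) / n
        d = [pk - (1.0 if k == y else 0.0) for k, pk in enumerate(p)]
        for k in range(len(c)):
            grads["C"][k] += d[k] * f / n
        # Chain rule back through C into G.
        grads["G"][0] += sum(dk * ck for dk, ck in zip(d, c)) * x / n
    return loss, grads


def step(params: Params, xs: List[float], ys: List[int], frozen: Set[str], lr: float = 0.1) -> float:
    """Hypothetical gradient step that leaves every model listed in `frozen` untouched."""
    loss, grads = loss_and_grads(params, xs, ys)
    for name, weights in params.items():
        if name in frozen:
            continue
        for i, gi in enumerate(grads[name]):
            weights[i] -= lr * gi
    return loss


xs, ys = [1.0, -1.0], [0, 1]
finetune = {"G": [1.0], "C": [0.5, -0.5]}
step(finetune, xs, ys, frozen={"G"})  # FinetunerHook-like: only C is updated
full = {"G": [1.0], "C": [0.5, -0.5]}
step(full, xs, ys, frozen=set())      # ClassifierHook-like: G and C are updated
print(finetune["G"], full["G"])       # [1.0] versus a changed value
```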
SoftmaxHook
Bases: ApplyFnHook
Applies torch.nn.Softmax(dim=1) to the specified inputs.
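Softmax with dim=1 normalizes each row of a (batch, classes) tensor so that the row sums to 1. The stdlib-only sketch below applies that operation to the entries of an inputs dictionary whose names are listed in apply_to. apply_softmax, the apply_to argument and the "src_logits"/"tgt_logits" keys are hypothetical, and the sketch does not reproduce how SoftmaxHook names or returns its outputs.

```python
import math
from typing import Dict, List

Batch = List[List[float]]  # (batch, classes)


def softmax_dim1(batch: Batch) -> Batch:
    """Softmax over each row, i.e. along dim=1 of a (batch, classes) tensor."""
    out = []
    for row in batch:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def apply_softmax(inputs: Dict[str, Batch], apply_to: List[str]) -> Dict[str, Batch]:
    """Hypothetical stand-in for SoftmaxHook: softmax only the named entries."""
    return {name: softmax_dim1(inputs[name]) for name in apply_to}


inputs = {"src_logits": [[0.0, 0.0], [1.0, 1.0]], "tgt_logits": [[5.0, 1.0]]}
outputs = apply_softmax(inputs, apply_to=["src_logits"])
print(outputs)  # {'src_logits': [[0.5, 0.5], [0.5, 0.5]]}
```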
SoftmaxLocallyHook
Bases: BaseWrapperHook
Applies torch.nn.Softmax(dim=1) to the specified inputs, which are overwritten only inside this hook: the wrapped hooks receive the softmaxed tensors, while the inputs outside this hook are left unchanged.
__init__(apply_to, *hooks, **kwargs)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
apply_to | List[str] | Names of the tensors that softmax will be applied to. | required |
*hooks | BaseHook | The hooks that will receive the softmaxed tensors. | () |
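The "only inside this hook" behavior can be read as: build a local copy of the inputs in which the apply_to tensors are replaced by their softmax, pass that copy to the wrapped hooks, and return only the hooks' outputs. The plain-Python sketch below assumes that reading. softmax_locally and the example hook mean_max_prob are hypothetical, and in this sketch the wrapped hooks are simply called one after another on the same local inputs.

```python
import math
from typing import Callable, Dict, List

Batch = List[List[float]]
Hook = Callable[[Dict[str, Batch]], Dict[str, float]]


def softmax_dim1(batch: Batch) -> Batch:
    """Row-wise softmax (dim=1), shifted by the row max for stability."""
    out = []
    for row in batch:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def softmax_locally(inputs: Dict[str, Batch], apply_to: List[str], *hooks: Hook) -> Dict[str, float]:
    """Hypothetical stand-in for SoftmaxLocallyHook (not the pytorch_adapt implementation)."""
    local = dict(inputs)  # shallow copy: the caller's dict is never modified
    for name in apply_to:
        local[name] = softmax_dim1(inputs[name])  # overwritten only in `local`
    outputs: Dict[str, float] = {}
    for hook in hooks:
        outputs.update(hook(local))  # wrapped hooks see the softmaxed tensors
    return outputs


def mean_max_prob(x: Dict[str, Batch]) -> Dict[str, float]:
    """A hypothetical wrapped hook: the average of each row's largest value."""
    rows = x["tgt_logits"]
    return {"mean_max_prob": sum(max(r) for r in rows) / len(rows)}


inputs = {"tgt_logits": [[0.0, 0.0], [3.0, 3.0]]}
print(softmax_locally(inputs, ["tgt_logits"], mean_max_prob))  # {'mean_max_prob': 0.5}
print(inputs)  # still {'tgt_logits': [[0.0, 0.0], [3.0, 3.0]]}
```

Called directly on the raw logits, the same mean_max_prob hook would return 1.5 instead, which is the difference the local softmax makes.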