dann
DANNHook
Bases: BaseWrapperHook
Implementation of Domain-Adversarial Training of Neural Networks.
This includes the model optimization step.
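DANN trains the feature extractor against a domain discriminator through a gradient reversal step. In the forward pass the step is the identity. In the backward pass it multiplies the incoming gradient by a negative weight. The scalar sketch below traces one backward pass by hand. It uses a linear feature extractor, a linear discriminator, and a squared-error domain loss purely for illustration (DANN itself uses a classification loss for the discriminator). All function names are hypothetical, and this is not pytorch_adapt's code. The `weight` argument is assumed to play the role of `gradient_reversal_weight`.

```python
def reverse_gradient(upstream_grad: float, weight: float = 1.0) -> float:
    """Backward pass of gradient reversal: scale the incoming gradient by -weight.

    The forward pass of gradient reversal is the identity, so it needs no function here.
    """
    return -weight * upstream_grad


def dann_scalar_step(x: float, y: float, w: float, v: float, weight: float = 1.0):
    """Return (domain_loss, grad_w, grad_v) for one hypothetical scalar example.

    Assumed toy model: feature f = w * x, reversed r = f (identity forward),
    discriminator logit z = v * r, domain loss L = (z - y) ** 2.
    """
    f = w * x
    r = f  # gradient reversal is the identity in the forward pass
    z = v * r
    loss = (z - y) ** 2

    dloss_dz = 2 * (z - y)
    grad_v = dloss_dz * r  # discriminator gradient is NOT reversed
    dloss_dr = dloss_dz * v
    dloss_df = reverse_gradient(dloss_dr, weight)  # sign flip before reaching the generator
    grad_w = dloss_df * x
    return loss, grad_w, grad_v


print(dann_scalar_step(1.0, 0.0, 2.0, 3.0))  # (36.0, -36.0, 24.0)
```

The discriminator's gradient is unchanged (24.0), but the generator's gradient comes out as -36.0 instead of +36.0. As a result, an ordinary descent step on `w` increases the domain loss and pushes the features toward being domain-indistinguishable.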
Source code in pytorch_adapt\hooks\dann.py
__init__(opts, weighter=None, reducer=None, pre=None, pre_d=None, post_d=None, pre_g=None, post_g=None, gradient_reversal=None, gradient_reversal_weight=1, use_logits=False, f_hook=None, d_hook=None, c_hook=None, domain_loss_hook=None, d_hook_allowed='_dlogits$', **kwargs)
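According to the parameter descriptions, the reducer reduces loss tensors and the weighter weights the losses before backpropagation. As a rough illustration only, the sketch below assumes two things: a reducer that takes the mean of per-sample losses, and a weighter that multiplies each named loss by a weight (default 1.0) and sums the results. The loss names, functions, and combination rule are hypothetical, not pytorch_adapt's defaults.

```python
from typing import Dict, List, Optional


def mean_reduce(per_sample: Dict[str, List[float]]) -> Dict[str, float]:
    """Hypothetical reducer: collapse each list of per-sample losses to its mean."""
    return {name: sum(values) / len(values) for name, values in per_sample.items()}


def weighted_sum(losses: Dict[str, float], weights: Optional[Dict[str, float]] = None) -> float:
    """Hypothetical weighter: multiply each loss by its weight (default 1.0) and add them up."""
    weights = weights or {}
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


# Usage with made-up loss names and values.
reduced = mean_reduce({"c_loss": [1.0, 3.0], "domain_loss": [0.5, 0.5]})
total = weighted_sum(reduced, {"domain_loss": 2.0})
print(reduced, total)  # {'c_loss': 2.0, 'domain_loss': 0.5} 3.0
```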
Parameters:
Name | Type | Description | Default |
---|---|---|---|
opts | | List of optimizers for updating the models. | required |
weighter | | Weights the losses before backpropagation. If None, a default weighter is used. | None |
reducer | | Reduces loss tensors. If None, a default reducer is used. | None |
pre | | List of hooks executed at the very beginning of each iteration. | None |
pre_d | | List of hooks executed after gradient reversal but before the domain loss. | None |
post_d | | List of hooks executed after gradient reversal and after the domain loss. | None |
pre_g | | List of hooks executed outside of the gradient reversal step, before the generator and classifier losses. | None |
post_g | | List of hooks executed after the generator and classifier losses. | None |
gradient_reversal | | Called before all D hooks, including pre_d. | None |
use_logits | | If True, D receives the output of C instead of the output of G. | False |
f_hook | | The hook used for computing features and logits. If None, a default hook is used. | None |
d_hook | | The hook used for computing discriminator logits. If None, a default hook is used. | None |
c_hook | | The hook used for computing the classifier's loss. If None, a default hook is used. | None |
domain_loss_hook | | The hook used for computing the domain loss. If None, a default hook is used. | None |
d_hook_allowed | | A regex string that specifies the allowed output names of the discriminator block. | '_dlogits$' |
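Taken together, the descriptions of pre, pre_d, post_d, pre_g, and post_g imply an ordering within one iteration. The sketch below records that order, using plain strings as stand-ins for hooks. It makes two assumptions that the parameter table does not state: the discriminator branch runs before the generator/classifier branch, and the optimization step can be left out. The function and step names are hypothetical.

```python
from typing import List


def dann_iteration_order(
    pre: List[str], pre_d: List[str], post_d: List[str], pre_g: List[str], post_g: List[str]
) -> List[str]:
    """Return the order in which named hooks run in one hypothetical DANN iteration."""
    order: List[str] = []
    order += pre  # very beginning of the iteration
    # Discriminator branch: gradient reversal runs before every D hook, including pre_d.
    order.append("gradient_reversal")
    order += pre_d
    order.append("domain_loss")
    order += post_d
    # Generator/classifier branch: outside the gradient reversal step.
    order += pre_g
    order.append("generator_and_classifier_losses")
    order += post_g
    return order


print(dann_iteration_order(["pre"], ["pre_d"], ["post_d"], ["pre_g"], ["post_g"]))
# ['pre', 'gradient_reversal', 'pre_d', 'domain_loss', 'post_d', 'pre_g',
#  'generator_and_classifier_losses', 'post_g']
```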