gan
GANHook
Bases: BaseWrapperHook
A generic GAN architecture for domain adaptation. This includes the model optimization steps.
Source code in pytorch_adapt\hooks\gan.py
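The class description says this hook bundles both phases of GAN training together with the optimizer updates. The pure-Python sketch below only illustrates the call order implied by the pre_d, post_d, pre_g and post_g parameter descriptions: pre hooks run at the very beginning of a phase, and post hooks run after the loss but before the optimizers are called. It assumes the D phase runs before the G phase and omits backpropagation and gradient zeroing. The names RecordingOptimizer, make_hook, run_phase and gan_training_step are hypothetical and are not part of pytorch_adapt.

```python
from typing import Callable, List

Log = List[str]


class RecordingOptimizer:
    """Hypothetical stand-in for an optimizer: it only records step() calls."""

    def __init__(self, name: str, log: Log) -> None:
        self.name = name
        self.log = log

    def step(self) -> None:
        self.log.append(f"{self.name}.step")


def make_hook(name: str, log: Log) -> Callable[[], None]:
    """Hypothetical stand-in for a hook: it only records that it ran."""
    return lambda: log.append(name)


def run_phase(
    pre: List[Callable[[], None]],
    compute_loss: Callable[[], None],
    post: List[Callable[[], None]],
    opts: List[RecordingOptimizer],
) -> None:
    # Pre hooks run at the very beginning of the phase.
    for hook in pre:
        hook()
    compute_loss()
    # Post hooks run at the end of the phase, before any optimizer step.
    for hook in post:
        hook()
    for opt in opts:
        opt.step()


def gan_training_step(log: Log) -> None:
    # Assumed ordering for this sketch: D phase first, then G phase.
    run_phase(
        pre=[make_hook("pre_d", log)],
        compute_loss=make_hook("d_loss", log),
        post=[make_hook("post_d", log)],
        opts=[RecordingOptimizer("d_opt", log)],
    )
    run_phase(
        pre=[make_hook("pre_g", log)],
        compute_loss=make_hook("g_loss", log),
        post=[make_hook("post_g", log)],
        opts=[RecordingOptimizer("g_opt", log)],
    )


log: Log = []
gan_training_step(log)
print(log)
# ['pre_d', 'd_loss', 'post_d', 'd_opt.step', 'pre_g', 'g_loss', 'post_g', 'g_opt.step']
```

Keeping the two phases in one helper makes the symmetry explicit: each phase differs only in its hooks, its loss, and its list of optimizers (d_opts or g_opts).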
__init__(d_opts, g_opts, d_weighter=None, d_reducer=None, g_weighter=None, g_reducer=None, pre_d=None, post_d=None, pre_g=None, post_g=None, use_logits=False, disc_hook=None, gen_hook=None, disc_f_hook=None, gen_f_hook=None, disc_d_hook=None, gen_d_hook=None, c_hook=None, disc_conditions=None, disc_alts=None, gen_conditions=None, gen_alts=None, disc_domains=None, gen_domains=None, disc_domain_loss_fn=None, gen_domain_loss_fn=None, **kwargs)
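The signature pairs a weighter and a reducer with each phase (d_weighter/d_reducer and g_weighter/g_reducer). As a rough illustration only, the sketch below assumes that a reducer collapses each per-sample loss into a scalar by averaging, and that a weighter combines the named scalar losses into one total via a weighted sum with a default weight of 1.0. The helpers mean_reducer and weighted_sum and the loss names are hypothetical; they are not pytorch_adapt's weighter or reducer classes, whose defaults are not stated here.

```python
from typing import Dict, List, Optional


def mean_reducer(losses: Dict[str, List[float]]) -> Dict[str, float]:
    """Hypothetical reducer: average each per-sample loss into a scalar."""
    return {name: sum(values) / len(values) for name, values in losses.items()}


def weighted_sum(
    losses: Dict[str, float], weights: Optional[Dict[str, float]] = None
) -> float:
    """Hypothetical weighter: scale each loss (default weight 1.0) and sum."""
    weights = weights or {}
    return sum(value * weights.get(name, 1.0) for name, value in losses.items())


# Example for one phase with two illustrative loss names.
per_sample = {"domain_loss": [0.5, 1.5], "c_loss": [2.0, 4.0]}
reduced = mean_reducer(per_sample)  # {'domain_loss': 1.0, 'c_loss': 3.0}
total = weighted_sum(reduced, {"c_loss": 0.5})
print(total)  # 2.5
```

Passing separate objects for the D and G phases means the two phases can weight and reduce their losses independently.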
Parameters:
Name | Type | Description | Default |
---|---|---|---|
d_opts | | List of optimizers for the D phase. | required |
g_opts | | List of optimizers for the G phase. | required |
d_weighter | | A loss weighter for the D phase. If … | None |
d_reducer | | A loss reducer for the D phase. If … | None |
g_weighter | | A loss weighter for the G phase. If … | None |
g_reducer | | A loss reducer for the G phase. If … | None |
pre_d | | List of hooks that will be executed at the very beginning of the D phase. | None |
post_d | | List of hooks that will be executed at the end of the D phase, but before the optimizers are called. | None |
pre_g | | List of hooks that will be executed at the very beginning of the G phase. | None |
post_g | | List of hooks that will be executed at the end of the G phase, but before the optimizers are called. | None |
use_logits | | If … | False |
disc_hook | | The hook used for computing the discriminator's domain loss. If … | None |
gen_hook | | The hook used for computing the generator's domain loss. If … | None |
c_hook | | The hook used for computing the classifier's loss. If … | None |
disc_conditions | | The condition hooks used in the … | None |
disc_alts | | The alt hooks used in the … | None |
gen_conditions | | The condition hooks used in the … | None |
gen_alts | | The alt hooks used in the … | None |
disc_domains | | The domains used to compute the discriminator's domain loss. If … | None |
gen_domains | | The domains used to compute the generator's domain loss. If … | None |
disc_domain_loss_fn | | The loss function used to compute the discriminator's domain loss. If … | None |
gen_domain_loss_fn | | The loss function used to compute the generator's domain loss. If … | None |
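The disc_domains, gen_domains, disc_domain_loss_fn and gen_domain_loss_fn parameters select which domains contribute to each domain loss and which loss function is applied. The pure-Python sketch below shows one common GAN formulation under explicit assumptions: the loss is binary cross-entropy on raw discriminator logits, the discriminator is trained to output 0 for "src" and 1 for "target", and the generator uses flipped labels to fool it. These labels, the domain names, and the helpers bce_with_logits and domain_loss are hypothetical and are not taken from the hook's internals.

```python
import math
from typing import Dict, List


def bce_with_logits(logit: float, label: float) -> float:
    """Numerically stable binary cross-entropy computed on a raw logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def domain_loss(
    logits: Dict[str, List[float]],
    domains: List[str],
    labels: Dict[str, float],
) -> Dict[str, float]:
    """Mean BCE loss per domain, computed only for the requested domains."""
    return {
        d: sum(bce_with_logits(x, labels[d]) for x in logits[d]) / len(logits[d])
        for d in domains
    }


# Assumed labeling: D should output 0 for "src" and 1 for "target".
D_LABELS = {"src": 0.0, "target": 1.0}
# Assumed generator objective: fool D by using flipped labels.
G_LABELS = {"src": 1.0, "target": 0.0}

logits = {"src": [0.0, 0.0], "target": [0.0]}
# The discriminator uses both domains; this generator uses only "target".
disc_losses = domain_loss(logits, ["src", "target"], D_LABELS)
gen_losses = domain_loss(logits, ["target"], G_LABELS)
print(disc_losses, gen_losses)
```

Restricting the generator to a subset of domains, as in this example, shows why the discriminator and generator domain lists are configured separately.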