vat_loss
VATLoss
Bases: torch.nn.Module
Implementation of the loss used in:
- Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
- A DIRT-T Approach to Unsupervised Domain Adaptation
Source code in pytorch_adapt\layers\vat_loss.py
__init__(num_power_iterations=1, xi=1e-06, epsilon=8.0)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_power_iterations | int | The number of iterations for computing the approximation of the adversarial perturbation. | 1 |
xi | float | The L2 norm of the generated noise, which is used in the process of creating the perturbation. | 1e-06 |
epsilon | float | The L2 norm of the generated perturbation. | 8.0 |
forward(imgs, logits, model)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
imgs | torch.Tensor | The input to the model. | required |
logits | torch.Tensor | The model's logits computed from `imgs`. | required |
model | torch.nn.Module | The model that produced `logits`. | required |