Skip to content

vat_loss

VATLoss

Bases: torch.nn.Module

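Implementation of the loss used in Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning (https://arxiv.org/abs/1704.03976) and A DIRT-T Approach to Unsupervised Domain Adaptation (https://arxiv.org/abs/1802.08735).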

Source code in pytorch_adapt\layers\vat_loss.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class VATLoss(torch.nn.Module):
    """
    Implementation of the loss used in

    - [Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning](https://arxiv.org/abs/1704.03976)
    - [A DIRT-T Approach to Unsupervised Domain Adaptation](https://arxiv.org/abs/1802.08735)

    The returned value is the KL divergence between the model's predictions
    on the original inputs and its predictions on perturbed inputs.
    """

    def __init__(
        self, num_power_iterations: int = 1, xi: float = 1e-6, epsilon: float = 8.0
    ):
        """
        Arguments:
            num_power_iterations: The number of iterations for
                computing the approximation of the adversarial perturbation.
            xi: The L2 norm of the generated noise which is used
                in the process of creating the perturbation.
            epsilon: The L2 norm of the generated perturbation.
        """
        super().__init__()
        self.num_power_iterations = num_power_iterations
        self.xi = xi
        self.epsilon = epsilon
        self.kl_div = torch.nn.KLDivLoss(reduction="batchmean")
        pml_cf.add_to_recordable_attributes(
            self, list_of_names=["num_power_iterations", "xi", "epsilon"]
        )

    def forward(
        self, imgs: torch.Tensor, logits: torch.Tensor, model: torch.nn.Module
    ) -> torch.Tensor:
        """
        Arguments:
            imgs: The input to the model
            logits: The model's logits computed from ```imgs```
            model: The aforementioned model

        Returns:
            The KL divergence between the softmax of ```logits``` (detached,
            so it acts as a fixed target) and the log-softmax of the model's
            output on the perturbed input.
        """
        logits = logits.detach()
        model.apply(c_f.set_layers_mode("eval", c_f.batchnorm_types()))
        try:
            perturbation = self.get_perturbation(imgs, logits, model)
            new_logits = model(imgs + perturbation)
        finally:
            # Always undo the mode switch, even if the model raises, so the
            # batchnorm layers are not left stuck in eval mode.
            # NOTE(review): this unconditionally selects "train", as the
            # original code did; it does not restore a prior eval state.
            model.apply(c_f.set_layers_mode("train", c_f.batchnorm_types()))

        preds = F.softmax(logits, dim=1)
        new_preds = F.log_softmax(new_logits, dim=1)
        return self.kl_div(new_preds, preds)

    def get_perturbation(self, imgs, original_logits, model):
        """
        Approximates the adversarial perturbation for ```imgs```.

        Arguments:
            imgs: The input to the model
            original_logits: The model's logits computed from ```imgs```
            model: The model used to evaluate the perturbed inputs

        Returns:
            ```epsilon``` times the normalized noise obtained after
            ```num_power_iterations``` gradient steps.
        """
        noise = torch.randn(*imgs.shape, device=original_logits.device)
        original_preds = F.softmax(original_logits, dim=1)

        for _ in range(self.num_power_iterations):
            # Detach so the scaled noise is a leaf tensor we can take the
            # gradient with respect to.
            noise = (self.xi * get_normalized_noise(noise)).detach()
            noise.requires_grad_(True)
            new_preds = F.log_softmax(model(imgs + noise), dim=1)
            dist = self.kl_div(new_preds, original_preds)
            # Only the gradient w.r.t. the noise is needed. Using
            # torch.autograd.grad leaves the parameters' .grad untouched, so
            # there is no need for model.zero_grad(), which would also erase
            # gradients accumulated before this call.
            (noise,) = torch.autograd.grad(dist, noise)
            noise = noise.detach()

        return self.epsilon * get_normalized_noise(noise)

    def extra_repr(self):
        """"""
        # The empty docstring is kept from the original code.
        return c_f.extra_repr(self, ["num_power_iterations", "xi", "epsilon"])

__init__(num_power_iterations=1, xi=1e-06, epsilon=8.0)

Parameters:

Name Type Description Default
num_power_iterations int

The number of iterations for computing the approximation of the adversarial perturbation.

1
xi float

The L2 norm of the generated noise which is used in the process of creating the perturbation.

1e-06
epsilon float

The L2 norm of the generated perturbation.

8.0
Source code in pytorch_adapt\layers\vat_loss.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def __init__(
    self, num_power_iterations: int = 1, xi: float = 1e-6, epsilon: float = 8.0
):
    """
    Arguments:
        num_power_iterations: How many iterations are used to
            approximate the adversarial perturbation.
        xi: The L2 norm of the generated noise that is used
            while creating the perturbation.
        epsilon: The L2 norm of the generated perturbation.
    """
    super().__init__()
    # Names of the hyperparameters registered as recordable attributes.
    recordable_names = ["num_power_iterations", "xi", "epsilon"]
    self.num_power_iterations = num_power_iterations
    self.xi = xi
    self.epsilon = epsilon
    self.kl_div = torch.nn.KLDivLoss(reduction="batchmean")
    pml_cf.add_to_recordable_attributes(self, list_of_names=recordable_names)

forward(imgs, logits, model)

Parameters:

Name Type Description Default
imgs torch.Tensor

The input to the model

required
logits torch.Tensor

The model's logits computed from imgs

required
model torch.nn.Module

The aforementioned model

required
Source code in pytorch_adapt\layers\vat_loss.py
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def forward(
    self, imgs: torch.Tensor, logits: torch.Tensor, model: torch.nn.Module
) -> torch.Tensor:
    """
    Arguments:
        imgs: The input to the model
        logits: The model's logits computed from ```imgs```
        model: The aforementioned model

    Returns:
        The KL divergence between the softmax of ```logits``` (detached,
        so it acts as a fixed target) and the log-softmax of the model's
        output on the perturbed input.
    """
    logits = logits.detach()
    model.apply(c_f.set_layers_mode("eval", c_f.batchnorm_types()))
    try:
        perturbation = self.get_perturbation(imgs, logits, model)
        new_logits = model(imgs + perturbation)
    finally:
        # Always undo the mode switch, even if the model raises, so the
        # batchnorm layers are not left stuck in eval mode.
        # NOTE(review): this unconditionally selects "train", as the
        # original code did; it does not restore a prior eval state.
        model.apply(c_f.set_layers_mode("train", c_f.batchnorm_types()))

    preds = F.softmax(logits, dim=1)
    new_preds = F.log_softmax(new_logits, dim=1)
    return self.kl_div(new_preds, preds)