
batch_spectral_loss

BatchSpectralLoss

Bases: torch.nn.Module

Implementation of the loss from "Transferability vs. Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation". The loss is the sum of the squares of the k largest singular values of the input batch.
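Written as a formula, and assuming the input is a batch of $b$ feature vectors stacked row-wise into a matrix $X \in \mathbb{R}^{b \times d}$ with singular values sorted as $\sigma_1(X) \ge \sigma_2(X) \ge \dots \ge \sigma_{\min(b,d)}(X) \ge 0$, the loss described above is:

$$
\mathcal{L}_{\text{BSP}}(X) = \sum_{i=1}^{k} \sigma_i(X)^2
$$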

Source code in pytorch_adapt\layers\batch_spectral_loss.py
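# Excerpt: torch, c_f and batch_spectral_loss are imported or defined earlier in this module (not shown).
class BatchSpectralLoss(torch.nn.Module):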
    """
    Implementation of the loss in
    [Transferability vs. Discriminability: Batch Spectral
    Penalization for Adversarial Domain Adaptation](http://proceedings.mlr.press/v97/chen19i.html).
    The loss is the sum of the squares of the first k singular values.
    """

    def __init__(self, k: int = 1):
        """
        Arguments:
            k: the number of singular values to include in the loss
        """
        super().__init__()
        self.k = k

    def forward(self, x):
        """"""
        return batch_spectral_loss(x, self.k)

    def extra_repr(self):
        """"""
        return c_f.extra_repr(self, ["k"])
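The `forward` method above delegates to a module-level `batch_spectral_loss(x, k)` function that this excerpt does not show. The block below is a minimal, dependency-free sketch of the documented quantity, not the library's code. It assumes `x` is a `b x d` matrix given as a list of rows and uses the fact that the squared singular values of `X` are the eigenvalues of `X^T X`, which it computes with cyclic Jacobi rotations. The names `gram`, `symmetric_eigenvalues` and `batch_spectral_loss_sketch` are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def gram(x: Matrix) -> Matrix:
    """Return X^T X (a d x d matrix) for a b x d matrix X given as a list of rows."""
    d = len(x[0])
    return [[sum(row[i] * row[j] for row in x) for j in range(d)] for i in range(d)]


def symmetric_eigenvalues(a: Matrix, sweeps: int = 100, tol: float = 1e-12) -> List[float]:
    """Eigenvalues of a symmetric matrix via cyclic Jacobi rotations (unsorted)."""
    n = len(a)
    a = [row[:] for row in a]  # work on a copy so the caller's matrix is untouched
    for _ in range(sweeps):
        # Stop once the off-diagonal mass is negligible.
        off = sum(a[i][j] * a[i][j] for i in range(n) for j in range(n) if i != j)
        if off < tol:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if a[p][q] == 0.0:
                    continue
                # Choose the rotation angle that zeroes a[p][q].
                theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                # A <- A J (update columns p and q).
                for k in range(n):
                    akp, akq = a[k][p], a[k][q]
                    a[k][p] = c * akp - s * akq
                    a[k][q] = s * akp + c * akq
                # A <- J^T A (update rows p and q).
                for k in range(n):
                    apk, aqk = a[p][k], a[q][k]
                    a[p][k] = c * apk - s * aqk
                    a[q][k] = s * apk + c * aqk
    return [a[i][i] for i in range(n)]


def batch_spectral_loss_sketch(x: Matrix, k: int) -> float:
    """Sum of the squares of the k largest singular values of x (hypothetical helper)."""
    # Squared singular values of X are the eigenvalues of X^T X; sort them descending.
    eigs = sorted(symmetric_eigenvalues(gram(x)), reverse=True)
    # Clamp tiny negative round-off to zero before summing the top k.
    return sum(max(e, 0.0) for e in eigs[:k])


if __name__ == "__main__":
    # Singular values of this matrix are 4 and 3, so k=1 keeps only 4**2.
    print(batch_spectral_loss_sketch([[3.0, 0.0], [0.0, 4.0]], k=1))  # 16.0
```

The pure-Python eigen-solver is only there to keep the sketch self-contained and is far too slow for real feature dimensions. With PyTorch, one way to express the same quantity is `torch.linalg.svdvals(x)[:k].pow(2).sum()`, since `torch.linalg.svdvals` returns singular values in descending order; whether `pytorch_adapt`'s own `batch_spectral_loss` is written this way is not shown in this excerpt.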

__init__(k=1)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| k | int | the number of singular values to include in the loss | 1 |
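To see how `k` shapes the penalty, write $L_k(X) = \sum_{i=1}^{k} \sigma_i(X)^2$ for a $b \times d$ input with singular values $\sigma_1(X) \ge \dots \ge \sigma_{\min(b,d)}(X) \ge 0$. Every term is non-negative, so for $1 \le k \le \min(b,d)$ the loss grows monotonically with `k` and is bounded by

$$
\sigma_1(X)^2 = L_1(X) \;\le\; L_k(X) \;\le\; \sum_{i=1}^{\min(b,d)} \sigma_i(X)^2 = \|X\|_F^2 ,
$$

with equality on the right once $k \ge \operatorname{rank}(X)$. The default `k=1` therefore penalizes only the dominant singular value, while a large `k` approaches a plain squared Frobenius-norm penalty. How values of `k` above $\min(b,d)$ are handled depends on the implementation, which this excerpt does not show.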