Skip to content

sliced_wasserstein

SlicedWasserstein

Bases: torch.nn.Module

Implementation of the loss used in Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation

Source code in pytorch_adapt\layers\sliced_wasserstein.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class SlicedWasserstein(torch.nn.Module):
    """
    Implementation of the loss used in
    [Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation](https://arxiv.org/abs/1903.04064)

    Both batches are projected onto ``m`` random unit-norm directions. Each
    projected column is sorted along the batch dimension, and the mean squared
    difference between the sorted values is returned. A fresh set of random
    directions is drawn on every call, so the output is stochastic.
    """

    def __init__(self, m: int = 128):
        """
        Arguments:
            m: The number of random projection directions, i.e. the
                dimensionality the inputs are projected to.
        """
        super().__init__()
        # Store the caller's value; it sets the number of projections in forward().
        self.m = m

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: a batch of class predictions
            y: the other batch of class predictions
        Returns:
            The discrepancy between the two batches of class predictions
            (a scalar tensor with the same dtype as ``x``).
        """
        d = x.shape[1]
        # Create the projection in the inputs' dtype/device so that matmul does
        # not fail for inputs whose dtype differs from the default (e.g. float64).
        proj = torch.randn(d, self.m, device=x.device, dtype=x.dtype)
        # Normalize each column so every projection direction has unit length.
        proj = torch.nn.functional.normalize(proj, dim=0)
        x = torch.matmul(x, proj)
        y = torch.matmul(y, proj)
        # Sorting along the batch dimension pairs up empirical quantiles of the
        # 1-D projected distributions, which is how 1-D Wasserstein is computed.
        x, _ = torch.sort(x, dim=0)
        y, _ = torch.sort(y, dim=0)
        return torch.mean((x - y) ** 2)

    def extra_repr(self):
        """"""
        # Expose `m` in the module's printed representation.
        return c_f.extra_repr(self, ["m"])

__init__(m=128)

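Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `m` | `int` | The number of random projection directions, i.e. the dimensionality to project to. | `128` |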
Source code in pytorch_adapt\layers\sliced_wasserstein.py
12
13
14
15
16
17
18
def __init__(self, m: int = 128):
    """
    Arguments:
        m: The number of random projection directions, i.e. the
            dimensionality the inputs are projected to.
    """
    super().__init__()
    # Store the caller's value; it sets the number of projections in forward().
    self.m = m

forward(x, y)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `torch.Tensor` | A batch of class predictions. | required |
| `y` | `torch.Tensor` | The other batch of class predictions. | required |

Returns:

| Type | Description |
|------|-------------|
| `torch.Tensor` | The discrepancy between the two batches of class predictions. |

Source code in pytorch_adapt\layers\sliced_wasserstein.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Arguments:
        x: a batch of class predictions
        y: the other batch of class predictions
    Returns:
        The discrepancy between the two batches of class predictions
        (a scalar tensor with the same dtype as ``x``). A fresh set of
        random projection directions is drawn on every call.
    """
    d = x.shape[1]
    # Create the projection in the inputs' dtype/device so that matmul does
    # not fail for inputs whose dtype differs from the default (e.g. float64).
    proj = torch.randn(d, self.m, device=x.device, dtype=x.dtype)
    # Normalize each column so every projection direction has unit length.
    proj = torch.nn.functional.normalize(proj, dim=0)
    x = torch.matmul(x, proj)
    y = torch.matmul(y, proj)
    # Sorting along the batch dimension pairs up empirical quantiles of the
    # 1-D projected distributions.
    x, _ = torch.sort(x, dim=0)
    y, _ = torch.sort(y, dim=0)
    return torch.mean((x - y) ** 2)