sliced_wasserstein
SlicedWasserstein
Bases: torch.nn.Module
Implementation of the loss used in "Sliced Wasserstein Discrepancy for Unsupervised Domain Adaptation".
Source code in pytorch_adapt\layers\sliced_wasserstein.py
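The following is a minimal sketch of how a sliced Wasserstein discrepancy between two batches of class predictions can be computed. It assumes the usual recipe: project both batches onto `m` random unit directions, sort each 1D projection (in one dimension, optimal transport matches sorted values), and average the squared differences. `sliced_wasserstein_sketch` is a hypothetical name, and the exact reduction used by `pytorch_adapt` may differ.

```python
import torch
import torch.nn.functional as F


def sliced_wasserstein_sketch(x: torch.Tensor, y: torch.Tensor, m: int = 128) -> torch.Tensor:
    """Hypothetical re-implementation of a sliced Wasserstein discrepancy.

    Assumes x and y are (batch_size, num_classes) tensors of the same shape.
    """
    d = x.shape[1]
    # Draw m random directions in R^d; dim=0 normalizes each column to unit length.
    proj = F.normalize(torch.randn(d, m, device=x.device, dtype=x.dtype), dim=0)
    # Project both batches onto every direction: shape (batch_size, m).
    x_proj = x @ proj
    y_proj = y @ proj
    # Sorting each column pairs up values as the 1D optimal transport plan would.
    x_sorted, _ = torch.sort(x_proj, dim=0)
    y_sorted, _ = torch.sort(y_proj, dim=0)
    # Mean squared distance between matched values, over samples and directions.
    return torch.mean((x_sorted - y_sorted) ** 2)


torch.manual_seed(0)
x = torch.softmax(torch.randn(32, 10), dim=1)
y = torch.softmax(torch.randn(32, 10), dim=1)
loss = sliced_wasserstein_sketch(x, y, m=128)
print(loss.shape)                              # torch.Size([])
print(sliced_wasserstein_sketch(x, x).item())  # 0.0
```

Identical batches give exactly zero because both are projected with the same directions inside a single call.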
__init__(m=128)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
m | int | The dimensionality to project to. | 128 |
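Assuming `m` is the number of random unit directions each prediction vector is projected onto, so that a batch of shape `(batch_size, num_classes)` becomes `(batch_size, m)`, the projection step can be sketched as follows. `project_batch` is a hypothetical helper, not part of `pytorch_adapt`.

```python
import torch
import torch.nn.functional as F


def project_batch(x: torch.Tensor, m: int = 128) -> torch.Tensor:
    """Hypothetical helper: project each row of x onto m random unit directions."""
    d = x.shape[1]
    # One column per direction; dim=0 normalizes each column to unit L2 norm.
    directions = F.normalize(torch.randn(d, m, dtype=x.dtype, device=x.device), dim=0)
    return x @ directions  # shape: (batch_size, m)


x = torch.softmax(torch.randn(16, 10), dim=1)
print(project_batch(x).shape)       # torch.Size([16, 128])
print(project_batch(x, m=4).shape)  # torch.Size([16, 4])
```

A larger `m` compares the two batches along more 1D slices, which typically lowers the variance of the random estimate at a higher compute cost.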
forward(x, y)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | A batch of class predictions. | required |
y | torch.Tensor | The other batch of class predictions. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The discrepancy between the two batches of class predictions. |
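When the discrepancy is computed from sorted projections, it compares the two batches as distributions rather than row by row, and it stays differentiable, so it can serve as a training loss. The sketch below checks both properties with a fixed set of directions. `swd_with_directions` is a hypothetical helper that assumes a squared-difference reduction over sorted projections; it is not the library's `forward`.

```python
import torch
import torch.nn.functional as F


def swd_with_directions(x: torch.Tensor, y: torch.Tensor, directions: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sliced discrepancy for a fixed (d, m) direction matrix."""
    x_sorted, _ = torch.sort(x @ directions, dim=0)
    y_sorted, _ = torch.sort(y @ directions, dim=0)
    return torch.mean((x_sorted - y_sorted) ** 2)


torch.manual_seed(0)
directions = F.normalize(torch.randn(10, 128), dim=0)
logits_x = torch.randn(32, 10, requires_grad=True)
x = torch.softmax(logits_x, dim=1)
y = torch.softmax(torch.randn(32, 10), dim=1)

loss = swd_with_directions(x, y, directions)
# Shuffling the rows of y leaves the value unchanged: the batches are compared as sets.
shuffled = swd_with_directions(x, y[torch.randperm(32)], directions)
print(torch.allclose(loss, shuffled))  # True

# The result is a differentiable scalar, so gradients reach the logits.
loss.backward()
print(logits_x.grad.shape)  # torch.Size([32, 10])
```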