mmd_loss
MMDBatchedLoss
Bases: MMDLoss
Source code in pytorch_adapt\layers\mmd_loss.py
forward(x, y)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | features from one domain. | required |
y | torch.Tensor | features from the other domain. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The MMD between x and y. |
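The section does not say how MMDBatchedLoss batches its computation. As a hypothetical sketch (not the library's code), one way to bound memory is to accumulate the kernel sums of the biased quadratic MMD² estimate block by block, so that only batch_size x batch_size kernel blocks are ever built. The function names, the batch_size parameter, and the single Gaussian kernel with a fixed bandwidth are all assumptions made for illustration.

```python
import torch


def gaussian_kernel_sum(a, b, bandwidth=1.0):
    """Sum of exp(-||a_i - b_j||^2 / bandwidth) over all pairs (i, j)."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=-1)
    return torch.exp(-sq_dist / bandwidth).sum()


def mmd_quadratic_batched(x, y, batch_size=64, bandwidth=1.0):
    """Biased quadratic-time MMD^2 with kernel sums accumulated block by block.

    Hypothetical helper: only batch_size x batch_size kernel blocks exist
    at any one time, never the full (n + m) x (n + m) kernel matrix.
    """

    def block_sum(a, b):
        total = a.new_zeros(())
        for i in range(0, len(a), batch_size):
            for j in range(0, len(b), batch_size):
                total = total + gaussian_kernel_sum(
                    a[i:i + batch_size], b[j:j + batch_size], bandwidth
                )
        return total

    n, m = len(x), len(y)
    return (
        block_sum(x, x) / n**2
        + block_sum(y, y) / m**2
        - 2 * block_sum(x, y) / (n * m)
    )
```

Because the blocks cover every pair (i, j) exactly once, the result matches the unbatched estimate up to floating-point rounding. When gradients are required, autograd still keeps each block's intermediates for the backward pass, so the saving is largest under torch.no_grad().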
MMDLoss
Bases: torch.nn.Module
Implementation of
__init__(kernel_scales=1, mmd_type='linear', dist_func=None, bandwidth=None)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
kernel_scales | Union[float, torch.Tensor] | The kernel bandwidth is scaled by this amount. If a tensor, multiple kernel bandwidths are used (one per element). | 1 |
mmd_type | str | Either 'linear' or 'quadratic'. 'linear' uses the linear-time estimate of MK-MMD; 'quadratic' uses the quadratic-time estimate computed from full kernel matrices. | 'linear' |
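To illustrate the two parameters, here is a minimal sketch of a multi-kernel Gaussian MMD², assuming Gaussian kernels exp(-||a - b||^2 / (s * bandwidth)) summed over the scales s in kernel_scales and a fixed base bandwidth. The linear estimate averages h = k(x1, x2) + k(y1, y2) - k(x1, y2) - k(x2, y1) over disjoint sample quadruples; the quadratic estimate uses full kernel matrices. The helper names are hypothetical, and the library's kernel, default bandwidth, and dist_func handling may differ.

```python
import torch


def multi_gaussian_kernel(a, b, kernel_scales=1.0, bandwidth=1.0):
    """Sum over scales s of exp(-||a - b||^2 / (s * bandwidth)).

    a and b must broadcast against each other; the last dim holds the features.
    """
    sq_dist = ((a - b) ** 2).sum(dim=-1)
    # One leading axis for the scales, broadcast against the distance tensor.
    scales = torch.as_tensor(kernel_scales, dtype=a.dtype, device=a.device)
    scales = scales.reshape(-1, *([1] * sq_dist.dim()))
    return torch.exp(-sq_dist.unsqueeze(0) / (scales * bandwidth)).sum(dim=0)


def mmd_linear(x, y, kernel_scales=1.0, bandwidth=1.0):
    """Linear-time MMD^2 estimate over disjoint quadruples (x1, x2, y1, y2)."""
    n = min(len(x), len(y)) // 2 * 2  # even number of samples from each domain
    x1, x2 = x[0:n:2], x[1:n:2]
    y1, y2 = y[0:n:2], y[1:n:2]

    def k(a, b):
        return multi_gaussian_kernel(a, b, kernel_scales, bandwidth)

    h = k(x1, x2) + k(y1, y2) - k(x1, y2) - k(x2, y1)
    return h.mean()


def mmd_quadratic(x, y, kernel_scales=1.0, bandwidth=1.0):
    """Biased quadratic-time MMD^2 estimate from full kernel matrices."""

    def k(a, b):
        # (n, 1, d) against (1, m, d) broadcasts to an (n, m) kernel matrix.
        return multi_gaussian_kernel(a[:, None, :], b[None, :, :], kernel_scales, bandwidth)

    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()
```

The linear estimate evaluates O(n) kernel values, which makes it cheap but noisier; the quadratic estimate evaluates O(n^2) kernel values and is more stable for a given batch.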
forward(x, y)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Union[torch.Tensor, List[torch.Tensor]] | features or a list of features from one domain. | required |
y | Union[torch.Tensor, List[torch.Tensor]] | features or a list of features from the other domain. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | MMD if the inputs are tensors, and Joint MMD (JMMD) if the inputs are lists of tensors. |
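Joint MMD replaces the single kernel with a product of per-layer kernels, so two samples score as similar only if they are close in every layer (for example, in both features and softmax predictions). The hypothetical sketch below assumes one Gaussian kernel per layer with a shared fixed bandwidth and the biased quadratic estimate; it mirrors the documented input handling, where a tensor gives MMD² and a list of tensors gives JMMD². The function names are illustrative, not the library's.

```python
import torch


def gaussian_kernel_matrix(a, b, bandwidth=1.0):
    """(n, m) matrix of exp(-||a_i - b_j||^2 / bandwidth)."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(dim=-1)
    return torch.exp(-sq_dist / bandwidth)


def mmd_or_jmmd(x, y, bandwidth=1.0):
    """Biased MMD^2 for tensors, JMMD^2 for equal-length lists of tensors."""
    xs = list(x) if isinstance(x, (list, tuple)) else [x]
    ys = list(y) if isinstance(y, (list, tuple)) else [y]
    if len(xs) != len(ys):
        raise ValueError("x and y must contain the same number of layers")
    # The joint kernel is the elementwise product of the per-layer kernels.
    kxx = kyy = kxy = 1.0
    for xl, yl in zip(xs, ys):
        kxx = kxx * gaussian_kernel_matrix(xl, xl, bandwidth)
        kyy = kyy * gaussian_kernel_matrix(yl, yl, bandwidth)
        kxy = kxy * gaussian_kernel_matrix(xl, yl, bandwidth)
    return kxx.mean() + kyy.mean() - 2 * kxy.mean()
```

With a single layer the product reduces to one kernel, so passing a tensor or a one-element list returns the same value.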