
coral_loss

CORALLoss

Bases: torch.nn.Module

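Implementation of the loss from Deep CORAL: Correlation Alignment for Deep Domain Adaptation (https://arxiv.org/abs/1607.01719). The loss penalizes the difference between the covariance matrices of features from two domains.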
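For reference, the Deep CORAL paper defines the loss as shown below. Here $d$ is the feature dimension, $D_S$ and $D_T$ are the $n_S \times d$ and $n_T \times d$ feature matrices of the two domains, $\mathbf{1}$ is a column vector of ones, and $\lVert\cdot\rVert_F$ is the Frobenius norm. The first line matches what `forward` on this page returns (squared Frobenius norm divided by $4d^2$). The second line is the paper's covariance estimator; the `covariance` helper that `forward` calls is not shown on this page, so it is an assumption that it uses exactly this normalization.

$$
\ell_{\text{CORAL}} = \frac{1}{4d^2}\,\lVert C_S - C_T \rVert_F^2
$$

$$
C_S = \frac{1}{n_S - 1}\left(D_S^\top D_S - \frac{1}{n_S}\left(\mathbf{1}^\top D_S\right)^\top\left(\mathbf{1}^\top D_S\right)\right)
$$

$C_T$ is defined the same way from $D_T$ and $n_T$.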

Source code in pytorch_adapt\layers\coral_loss.py
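import torch

# `covariance` is defined or imported earlier in pytorch_adapt\layers\coral_loss.py
# (not shown in this excerpt); it is assumed to return the (d, d) feature covariance.


class CORALLoss(torch.nn.Module):
    """
    Implementation of [Deep CORAL:
    Correlation Alignment for
    Deep Domain Adaptation](https://arxiv.org/abs/1607.01719)
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """
        Arguments:
            x: features from one domain
            y: features from the other domain
        """
        embedding_size = x.shape[1]  # feature dimension d
        cx = covariance(x)  # covariance of the first domain's features
        cy = covariance(y)  # covariance of the second domain's features
        # squared Frobenius norm of the covariance difference
        squared_fro_norm = torch.linalg.norm(cx - cy, ord="fro") ** 2
        # scale by 1 / (4 d^2), as in the Deep CORAL paper
        return squared_fro_norm / (4 * (embedding_size**2))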
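To see the whole computation without installing the library, here is a minimal standalone sketch. The `covariance` and `coral_loss` functions below are hypothetical stand-ins, not the library's code: the real helper is not shown on this page, so this version assumes it computes the unbiased sample covariance with rows as samples, as in the Deep CORAL paper.

```python
import torch


def covariance(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: unbiased sample covariance of a (batch_size, d) tensor.

    Rows are samples and columns are feature dimensions; the result is (d, d).
    """
    n = x.shape[0]
    centered = x - x.mean(dim=0, keepdim=True)  # subtract the per-feature mean
    return centered.t() @ centered / (n - 1)


def coral_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for CORALLoss.forward.

    Returns the squared Frobenius distance between the two covariance
    matrices, scaled by 1 / (4 d^2).
    """
    d = x.shape[1]  # feature dimension; must match between x and y
    diff = covariance(x) - covariance(y)
    return torch.linalg.norm(diff, ord="fro") ** 2 / (4 * d**2)
```

Because each batch's covariance is computed separately, `x` and `y` may have different batch sizes as long as they share the feature dimension. Also, since the covariance subtracts the mean, adding a constant to one domain's features leaves the loss unchanged: CORAL aligns second-order statistics only.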

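forward(x, y)

Computes the CORAL loss between two batches of features and returns it as a scalar tensor.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | torch.Tensor | features from one domain, a 2-D tensor of shape (batch_size, embedding_size) | required |
| y | torch.Tensor | features from the other domain, with the same embedding_size as x | required |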