uniform_distribution_loss

UniformDistributionLoss

Bases: torch.nn.Module

Implementation of the confusion loss from Simultaneous Deep Transfer Across Domains and Tasks.
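Written out from the `forward` method below (an editorial restatement, not the paper's notation), for a batch of $N$ logit vectors $x_i \in \mathbb{R}^C$ the loss is the cross-entropy between a uniform target distribution and $\operatorname{softmax}(x_i)$. By Jensen's inequality it is bounded below by $\log C$, with equality exactly when every softmax output is uniform. For a domain classifier that separates source from target, $C = 2$ and the minimum is $\log 2$.

```latex
\mathcal{L}(x) = -\frac{1}{N}\sum_{i=1}^{N}\frac{1}{C}\sum_{c=1}^{C}\log\operatorname{softmax}(x_i)_c
\;\ge\; -\frac{1}{N}\sum_{i=1}^{N}\log\Big(\frac{1}{C}\sum_{c=1}^{C}\operatorname{softmax}(x_i)_c\Big)
= -\log\frac{1}{C} = \log C
```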

Source code in pytorch_adapt\layers\uniform_distribution_loss.py
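```python
import torch
import torch.nn.functional as F


class UniformDistributionLoss(torch.nn.Module):
    """
    Implementation of the confusion loss from
    [Simultaneous Deep Transfer Across Domains and Tasks](https://arxiv.org/abs/1510.02192).
    """

    # *args lets this act as a drop-in replacement for CrossEntropyLoss:
    # extra arguments such as labels are accepted and ignored.
    def forward(self, x, *args):
        """Cross-entropy between a uniform target and softmax(x); x is (batch, classes)."""
        log_probs = F.log_softmax(x, dim=1)
        # Average log-probability over the classes of each sample.
        avg_log_probs = torch.mean(log_probs, dim=1)
        # Negate and average over the batch.
        return -torch.mean(avg_log_probs)
```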
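A short usage sketch, assuming only PyTorch (1.10 or later, which `F.cross_entropy` needs to accept probability targets). It re-declares a local copy of the class so it runs without `pytorch_adapt` installed, calls it the way `CrossEntropyLoss` would be called (with labels that are ignored), and checks that the result matches cross-entropy against a uniform soft target. The tensor shapes and variable names are illustrative only.

```python
import math

import torch
import torch.nn.functional as F


class UniformDistributionLoss(torch.nn.Module):
    # Local copy of the class above so this snippet runs without pytorch_adapt.
    def forward(self, x, *args):
        log_probs = F.log_softmax(x, dim=1)
        return -torch.mean(torch.mean(log_probs, dim=1))


torch.manual_seed(0)
logits = torch.randn(8, 2)          # e.g. outputs of a 2-way domain classifier
labels = torch.randint(0, 2, (8,))  # hard labels, accepted but ignored

loss_fn = UniformDistributionLoss()  # used where CrossEntropyLoss() would be
loss = loss_fn(logits, labels)

# Same value as cross-entropy with a uniform probability target per row.
uniform_targets = torch.full_like(logits, 1.0 / logits.shape[1])
ce = F.cross_entropy(logits, uniform_targets)
print(torch.allclose(loss, ce))  # True

# Lower bound log(C), attained when every softmax output is uniform.
flat_loss = loss_fn(torch.zeros(8, 2))
print(loss.item(), flat_loss.item(), math.log(2))
```

Because the labels are never read, minimizing this loss with respect to the feature extractor pushes the classifier's predictions toward uniform rather than toward any particular class.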