Skip to content

concat_softmax

ConcatSoftmax

Bases: torch.nn.Module

Applies softmax to the concatenation of a list of tensors.

Source code in pytorch_adapt/layers/concat_softmax.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class ConcatSoftmax(torch.nn.Module):
    """Concatenates its input tensors and applies softmax over the result."""

    def __init__(self, dim: int = 1):
        """
        Arguments:
            dim: a dimension along which softmax will be computed
        """
        super().__init__()
        self.dim = dim

    def forward(self, *x: torch.Tensor):
        """
        Arguments:
            *x: A sequence of tensors to be concatenated
        """
        # Joining along self.dim and normalizing along that same axis makes
        # every concatenated logit compete in a single softmax.
        return torch.cat(x, dim=self.dim).softmax(dim=self.dim)

    def extra_repr(self):
        """Delegates to the project's shared repr helper."""
        return c_f.extra_repr(self, ["dim"])

__init__(dim=1)

Parameters:

Name Type Description Default
dim int

a dimension along which softmax will be computed

1
Source code in pytorch_adapt/layers/concat_softmax.py
11
12
13
14
15
16
17
def __init__(self, dim: int = 1) -> None:
    """
    Arguments:
        dim: a dimension along which softmax will be computed
    """
    super().__init__()
    # Stored for forward(), where it serves as both the concatenation
    # axis and the softmax axis.
    self.dim = dim

forward(*x)

Parameters:

Name Type Description Default
*x torch.Tensor

A sequence of tensors to be concatenated

()
Source code in pytorch_adapt/layers/concat_softmax.py
19
20
21
22
23
24
25
def forward(self, *x: torch.Tensor) -> torch.Tensor:
    """
    Arguments:
        *x: A sequence of tensors to be concatenated
    """
    # Concatenate along self.dim, then normalize along that same axis so
    # all concatenated logits compete in a single softmax.
    all_logits = torch.cat(x, dim=self.dim)
    return torch.nn.functional.softmax(all_logits, dim=self.dim)