Skip to content

neighborhood_aggregation

NeighborhoodAggregation

Bases: torch.nn.Module

Implementation of the pseudo-labeling step from the paper *Domain Adaptation with Auxiliary Target Domain-Oriented Classifier* (https://arxiv.org/abs/2007.04171).

Source code in `pytorch_adapt/layers/neighborhood_aggregation.py` (lines 11–93):
class NeighborhoodAggregation(torch.nn.Module):
    """
    Implementation of the pseudo labeling step in
    [Domain Adaptation with Auxiliary Target Domain-Oriented Classifier](https://arxiv.org/abs/2007.04171).

    Two memory banks are kept as buffers, with one row per target-dataset
    sample. ``feat_memory`` holds L2-normalized features and ``pred_memory``
    holds sharpened predictions. A sample's pseudolabel is the argmax of the
    mean stored prediction of its ``k`` most similar memory entries (dot
    product of unit-norm features, i.e. cosine similarity), with the sample's
    own entry excluded when its index is known.
    """

    def __init__(
        self,
        dataset_size: int,
        feature_dim: int,
        num_classes: int,
        k: int = 5,
        T: float = 0.5,
    ):
        """
        Arguments:
            dataset_size: The number of samples in the target dataset.
            feature_dim: The feature dimensionality, i.e. at each iteration
                the features should be size ```(N, D)``` where N is batch size and
                D is ```feature_dim```.
            num_classes: The number of class labels in the target dataset.
            k: The number of nearest neighbors used to determine each
                sample's pseudolabel. Neighbors are selected with ``topk``
                over all ``dataset_size`` memory rows, so ``k`` may not
                exceed ``dataset_size``.
            T: Sharpening temperature. Stored predictions are
                ``softmax(logits) ** (1 / T)`` before normalization, so
                ``T < 1`` sharpens and ``T > 1`` flattens them.
        """

        super().__init__()
        # Random unit vectors act as placeholders until update_memory
        # writes real features for each dataset index.
        self.register_buffer(
            "feat_memory", F.normalize(torch.rand(dataset_size, feature_dim))
        )
        # Uniform class distribution until real predictions are written.
        self.register_buffer(
            "pred_memory", torch.ones(dataset_size, num_classes) / num_classes
        )
        self.k = k
        self.T = T

    def forward(
        self,
        features: torch.Tensor,
        logits: torch.Tensor = None,
        update: bool = False,
        idx: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Arguments:
            features: The features to compute pseudolabels for.
            logits: The logits from which predictions will be computed and
                stored in memory. Required if ```update = True```
            update: If True, the current batch of features and predictions
                is written to the memory bank.
            idx: A tensor containing the dataset indices that
                produced each row of ```features```. Each sample's own
                memory entry is excluded from its neighbor search.
                Required if ```update = True```. If None (and ``update`` is
                False), no memory entry is excluded.
        Returns:
            A tuple ``(pseudo_labels, mean_preds)``. ``pseudo_labels`` holds
            one class index per row of ``features``. ``mean_preds`` is the
            ``(N, num_classes)`` mean of the neighbors' stored predictions.
        Raises:
            ValueError: If ``update`` is True and ``logits`` or ``idx`` is None.
        """
        # Validate before touching memory: indexing with idx=None would
        # broadcast-write over every row of the memory banks.
        if update and (logits is None or idx is None):
            raise ValueError("logits and idx are required when update=True")
        # move to device if necessary
        self.feat_memory = pml_cf.to_device(self.feat_memory, features)
        self.pred_memory = pml_cf.to_device(self.pred_memory, features)
        with torch.no_grad():
            features = F.normalize(features)
            pseudo_labels, mean_preds = self.get_pseudo_labels(features, idx)
            if update:
                self.update_memory(features, logits, idx)
        return pseudo_labels, mean_preds

    def get_pseudo_labels(self, normalized_features, idx=None):
        """
        Compute pseudolabels from the ``k`` nearest memory entries.

        Arguments:
            normalized_features: L2-normalized features, one row per sample.
            idx: Dataset index of each row, used to exclude the sample's own
                memory entry. If None, no entry is excluded.
        Returns:
            ``(pseudo_labels, preds)``, where ``preds`` is the mean stored
            prediction over the ``k`` neighbors of each row.
        """
        # Similarity of every query to every memory row: (N, dataset_size).
        dis = torch.mm(normalized_features, self.feat_memory.t())
        if idx is not None:
            # Set self-comparisons to the global minimum similarity so a
            # sample is never its own neighbor. This is a single vectorized
            # write; the min is taken once, before any entry is changed.
            rows = torch.arange(dis.size(0), device=dis.device)
            cols = torch.as_tensor(idx, device=dis.device).reshape(-1).long()
            dis[rows, cols] = torch.min(dis)
        _, indices = torch.topk(dis, k=self.k, dim=1)
        # pred_memory[indices] is (N, k, num_classes); average the neighbors.
        preds = torch.mean(self.pred_memory[indices], dim=1)
        pseudo_labels = torch.argmax(preds, dim=1)
        return pseudo_labels, preds

    def update_memory(self, normalized_features, logits, idx):
        """
        Overwrite the memory rows at ``idx`` with the given features and
        sharpened predictions.

        Predictions are ``softmax(logits) ** (1 / T)``, normalized over
        dim 0, i.e. per class across the batch rather than per sample. The
        stored rows therefore do not generally sum to 1.
        """
        preds = F.softmax(logits, dim=1)
        p = 1.0 / self.T
        preds = (preds**p) / torch.sum(preds**p, dim=0)
        self.feat_memory[idx] = normalized_features
        self.pred_memory[idx] = preds

    def extra_repr(self):
        """"""
        # Empty docstring presumably keeps this out of the rendered API
        # docs — verify against the docs configuration.
        return c_f.extra_repr(self, ["k", "T"])

__init__(dataset_size, feature_dim, num_classes, k=5, T=0.5)

Parameters:

Name Type Description Default
dataset_size int

The number of samples in the target dataset.

required
feature_dim int

The feature dimensionality, i.e., at each iteration the features should have size (N, D), where N is the batch size and D is feature_dim.

required
num_classes int

The number of class labels in the target dataset.

required
k int

The number of nearest neighbors used to determine each sample's pseudolabel.

5
T float

The temperature used to sharpen predictions before they are stored in memory. The stored values are softmax(logits) ** (1 / T), normalized across the batch.

0.5
Source code in `pytorch_adapt/layers/neighborhood_aggregation.py` (lines 17–45):
def __init__(
    self,
    dataset_size: int,
    feature_dim: int,
    num_classes: int,
    k: int = 5,
    T: float = 0.5,
):
    """
    Create the feature and prediction memory banks, with one row per
    target-dataset sample.

    NOTE: this is a rendered excerpt of ``NeighborhoodAggregation.__init__``.
    The zero-argument ``super()`` call only works inside the class body.

    Arguments:
        dataset_size: The number of samples in the target dataset.
        feature_dim: The feature dimensionality, i.e. at each iteration
            the features should be size ```(N, D)``` where N is batch size and
            D is ```feature_dim```.
        num_classes: The number of class labels in the target dataset.
        k: The number of nearest neighbors used to determine each
            sample's pseudolabel.
        T: Sharpening temperature. Predictions are stored as
            ``softmax(logits) ** (1 / T)`` (then normalized).
    """

    super().__init__()
    # Random unit vectors act as placeholders until real features are
    # written for each dataset index.
    self.register_buffer(
        "feat_memory", F.normalize(torch.rand(dataset_size, feature_dim))
    )
    # Uniform class distribution until real predictions are written.
    self.register_buffer(
        "pred_memory", torch.ones(dataset_size, num_classes) / num_classes
    )
    self.k = k
    self.T = T

forward(features, logits=None, update=False, idx=None)

Parameters:

Name Type Description Default
features torch.Tensor

The features to compute pseudolabels for.

required
logits torch.Tensor

The logits from which predictions will be computed and stored in memory. Required if update=True.

None
update bool

If True, the current batch of predictions is added to the memory bank.

False
idx torch.Tensor

A tensor containing the dataset indices that produced each row of features.

None
Source code in `pytorch_adapt/layers/neighborhood_aggregation.py` (lines 47–72):
def forward(
    self,
    features: torch.Tensor,
    logits: torch.Tensor = None,
    update: bool = False,
    idx: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute pseudolabels for a batch and optionally write the batch into
    the memory bank.

    Arguments:
        features: The features to compute pseudolabels for.
        logits: The logits from which predictions will be computed and
            stored in memory. Required if ```update = True```
        update: If True, the current batch of features and predictions is
            written to the memory bank.
        idx: A tensor containing the dataset indices that
            produced each row of ```features```. Required if
            ```update = True```.
    Returns:
        ``(pseudo_labels, mean_preds)`` as produced by
        ``self.get_pseudo_labels`` on the L2-normalized features.
    Raises:
        ValueError: If ``update`` is True and ``logits`` or ``idx`` is None.
    """
    # Validate before touching memory: indexing with idx=None would
    # broadcast-write over every row of the memory banks.
    if update and (logits is None or idx is None):
        raise ValueError("logits and idx are required when update=True")
    # move to device if necessary
    self.feat_memory = pml_cf.to_device(self.feat_memory, features)
    self.pred_memory = pml_cf.to_device(self.pred_memory, features)
    with torch.no_grad():
        features = F.normalize(features)
        pseudo_labels, mean_preds = self.get_pseudo_labels(features, idx)
        if update:
            self.update_memory(features, logits, idx)
    return pseudo_labels, mean_preds