Skip to content

snd_validator

SNDValidator

Bases: BaseValidator

Implementation of *Tune it the Right Way: Unsupervised Validation of Domain Adaptation via Soft Neighborhood Density* (https://arxiv.org/abs/2108.10860).

Source code in `pytorch_adapt/validators/snd_validator.py` (lines 20–42).
class SNDValidator(BaseValidator):
    """
    Implementation of
    [Tune it the Right Way: Unsupervised Validation of Domain Adaptation via Soft Neighborhood Density](https://arxiv.org/abs/2108.10860)

    Args:
        layer: Key of ``target_train`` whose features are scored.
        T: Passed to ``get_iter_fn`` together with the entropy function
            (presumably the softmax temperature from the SND paper).
        batch_size: Batch size handed to ``BatchedDistance``.
        **kwargs: Forwarded to ``BaseValidator.__init__``.
    """

    def __init__(self, layer="preds", T=0.05, batch_size=1024, **kwargs):
        super().__init__(**kwargs)
        self.layer = layer
        self.T = T
        # return_mean=False: keep un-averaged entropies so they can be
        # concatenated across batches and averaged once at the end.
        self.entropy_fn = EntropyLoss(after_softmax=True, return_mean=False)
        # Cosine similarity is computed batch-by-batch via BatchedDistance.
        self.dist_fn = BatchedDistance(CosineSimilarity(), batch_size=batch_size)

    def compute_score(self, target_train):
        """
        Compute the SND score of ``target_train[self.layer]``.

        Args:
            target_train: Mapping that contains ``self.layer``.

        Returns:
            The mean of the collected entropies, as a Python number.

        Raises:
            ValueError: If the selected features are empty, or if the number
                of collected entropies differs from the number of feature rows.
        """
        features = target_train[self.layer]
        # Fail early with a clear message; otherwise torch.cat below would be
        # handed an empty list and raise a cryptic error.
        if len(features) == 0:
            raise ValueError(
                f"target_train[{self.layer!r}] is empty; cannot compute SND score"
            )
        # The callback built by get_iter_fn is installed as self.dist_fn.iter_fn
        # and is expected to append each batch's entropies to this list while
        # self.dist_fn runs over the features.
        batch_entropies = []
        self.dist_fn.iter_fn = get_iter_fn(batch_entropies, self.entropy_fn, self.T)
        self.dist_fn(features)
        all_entropies = torch.cat(batch_entropies, dim=0)
        # Sanity check: exactly one entropy per input row.
        if len(all_entropies) != len(features):
            raise ValueError(
                "all_entropies should have same length as input features "
                f"(got {len(all_entropies)} entropies for {len(features)} features)"
            )
        return torch.mean(all_entropies).item()