Skip to content

deep_embedded_validator

DeepEmbeddedValidator

Bases: BaseValidator

Implementation of the paper *Towards Accurate Model Selection in Deep Unsupervised Domain Adaptation* (PMLR vol. 97).

Source code in `pytorch_adapt/validators/deep_embedded_validator.py` (lines 36–93):
class DeepEmbeddedValidator(BaseValidator):
    """
    Implementation of
    [Towards Accurate Model Selection in Deep Unsupervised Domain Adaptation](http://proceedings.mlr.press/v97/you19a.html)

    The score is the negation of the value returned by ``get_dev_risk``.
    ``get_dev_risk`` is fed per-sample weights from ``get_weights`` together
    with the per-sample error of the model on the source validation split.

    Args:
        temp_folder: Forwarded to ``get_weights`` (presumably a scratch
            directory used while computing the weights — confirm).
        layer: Key used to read the inputs of ``get_weights`` from
            ``src_train``, ``src_val`` and ``target_train``.
        num_workers: Forwarded to ``get_weights``.
        batch_size: Forwarded to ``get_weights``.
        error_fn: Callable ``(outputs, labels) -> per-sample error``. When
            None (resolved through ``c_f.default``), it defaults to
            ``torch.nn.CrossEntropyLoss(reduction="none")``. Its result is
            indexed as ``error[:, None]``, so it must return one value per
            sample rather than a scalar.
        error_layer: Key in ``src_val`` whose value is passed to ``error_fn``.
        normalization: Validated by ``check_normalization`` and forwarded
            to ``get_dev_risk``.
        framework_fn: Forwarded to ``get_weights``. It defaults to
            ``default_framework_fn`` (resolved through ``c_f.default``).
        **kwargs: Passed to ``BaseValidator.__init__``.

    Recorded attributes (``None`` until ``compute_score`` runs):
        D_accuracy_val, D_accuracy_test: second and third values returned
            by ``get_weights``.
        mean_error: mean of the per-sample error on ``src_val``.
    """

    def __init__(
        self,
        temp_folder,
        layer="features",
        num_workers=0,
        batch_size=32,
        error_fn=None,
        error_layer="logits",
        normalization=None,
        framework_fn=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.temp_folder = temp_folder
        self.layer = layer
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.error_fn = c_f.default(
            error_fn, torch.nn.CrossEntropyLoss(reduction="none")
        )
        self.error_layer = error_layer
        # Fail fast on an invalid normalization option at construction time.
        check_normalization(normalization)
        self.normalization = normalization
        self.framework_fn = c_f.default(framework_fn, default_framework_fn)
        self.D_accuracy_val = None
        self.D_accuracy_test = None
        self.mean_error = None
        self._DEV_recordable = ["D_accuracy_val", "D_accuracy_test", "mean_error"]
        pml_cf.add_to_recordable_attributes(self, list_of_names=self._DEV_recordable)

    def compute_score(self, src_train, src_val, target_train):
        """
        Compute the validation score for the given splits.

        Args:
            src_train: Mapping containing ``self.layer``.
            src_val: Mapping containing ``self.layer``, ``self.error_layer``
                and ``"labels"``.
            target_train: Mapping containing ``self.layer``.

        Returns:
            The negated output of ``get_dev_risk``. Presumably it is negated
            so that a higher score means lower estimated risk — confirm
            against ``BaseValidator``'s convention.

        Side effects:
            Sets ``D_accuracy_val``, ``D_accuracy_test`` and ``mean_error``.
            ``c_f.LOGGER`` is held at ``logging.WARNING`` during the
            computation, and its previous level is restored even if an
            exception is raised.
        """
        init_logging_level = c_f.LOGGER.level
        c_f.LOGGER.setLevel(logging.WARNING)
        try:
            weights, self.D_accuracy_val, self.D_accuracy_test = get_weights(
                src_train[self.layer],
                src_val[self.layer],
                target_train[self.layer],
                self.num_workers,
                self.batch_size,
                self.temp_folder,
                self.framework_fn,
            )
            error_per_sample = self.error_fn(
                src_val[self.error_layer], src_val["labels"]
            )
            # get_dev_risk receives the errors as a column ([:, None]).
            output = get_dev_risk(
                weights, error_per_sample[:, None], self.normalization
            )
            self.mean_error = torch.mean(error_per_sample).item()
        finally:
            # Previously the level was only restored on success, so a
            # failure left the shared logger stuck at WARNING.
            c_f.LOGGER.setLevel(init_logging_level)
        return -output

    def extra_repr(self):
        """Extend the base repr with this validator's recorded attributes."""
        x = super().extra_repr()
        x += f"\n{c_f.extra_repr(self, self._DEV_recordable)}"
        return x