class DeepEmbeddedValidator(BaseValidator):
    """
    Implementation of
    [Towards Accurate Model Selection in Deep Unsupervised Domain Adaptation](http://proceedings.mlr.press/v97/you19a.html)

    How the score is computed:

    - ``get_weights`` derives per-sample weights from the ``layer`` outputs of
      the source-train, source-val and target-train splits.
    - ``error_fn`` computes per-sample errors on the source-val split.
    - ``get_dev_risk`` combines the weights and errors into a risk value.
    - The score returned by ``compute_score`` is the negated risk, so a lower
      risk yields a higher score.

    After each ``compute_score`` call, ``D_accuracy_val``, ``D_accuracy_test``
    and ``mean_error`` hold values from that computation. They are registered
    as recordable attributes and included in ``extra_repr``.
    """

    def __init__(
        self,
        temp_folder,
        layer="features",
        num_workers=0,
        batch_size=32,
        error_fn=None,
        error_layer="logits",
        normalization=None,
        framework_fn=None,
        **kwargs,
    ):
        """
        Args:
            temp_folder: Passed through to ``get_weights``.
            layer: Key used to select the inputs to ``get_weights`` from
                ``src_train``, ``src_val`` and ``target_train``.
            num_workers: Passed through to ``get_weights``.
            batch_size: Passed through to ``get_weights``.
            error_fn: Callable ``(predictions, labels) -> per-sample errors``.
                Defaults to ``torch.nn.CrossEntropyLoss(reduction="none")``.
                Its output is indexed with ``[:, None]``, so it is expected
                to have one entry per sample (presumably a 1-D tensor).
            error_layer: Key of ``src_val`` whose value is the first argument
                to ``error_fn``.
            normalization: Validated by ``check_normalization`` and forwarded
                to ``get_dev_risk``.
            framework_fn: Passed through to ``get_weights``. Defaults to
                ``default_framework_fn``.
            **kwargs: Forwarded to ``BaseValidator.__init__``.
        """
        super().__init__(**kwargs)
        self.temp_folder = temp_folder
        self.layer = layer
        self.num_workers = num_workers
        self.batch_size = batch_size
        # reduction="none" keeps one loss value per sample, which get_dev_risk
        # needs in order to weight samples individually.
        self.error_fn = c_f.default(
            error_fn, torch.nn.CrossEntropyLoss(reduction="none")
        )
        self.error_layer = error_layer
        check_normalization(normalization)
        self.normalization = normalization
        self.framework_fn = c_f.default(framework_fn, default_framework_fn)
        # Values from the most recent compute_score call (None until then).
        # The D_accuracy_* values come straight from get_weights; presumably
        # they are domain-classifier accuracies -- confirm against get_weights.
        self.D_accuracy_val = None
        self.D_accuracy_test = None
        self.mean_error = None
        self._DEV_recordable = ["D_accuracy_val", "D_accuracy_test", "mean_error"]
        pml_cf.add_to_recordable_attributes(self, list_of_names=self._DEV_recordable)

    def compute_score(self, src_train, src_val, target_train):
        """
        Compute the negated DEV risk.

        Args:
            src_train: Mapping containing ``self.layer``.
            src_val: Mapping containing ``self.layer``, ``self.error_layer``
                and ``"labels"``.
            target_train: Mapping containing ``self.layer``.

        Returns:
            The negation of the value returned by ``get_dev_risk``.

        Side effects:
            - Sets ``D_accuracy_val``, ``D_accuracy_test`` and ``mean_error``.
            - Lowers the verbosity of ``c_f.LOGGER`` to WARNING while running.
              The previous level is restored even if an exception is raised.
        """
        init_logging_level = c_f.LOGGER.level
        # Suppress sub-WARNING messages from c_f.LOGGER during weight estimation.
        c_f.LOGGER.setLevel(logging.WARNING)
        try:
            weights, self.D_accuracy_val, self.D_accuracy_test = get_weights(
                src_train[self.layer],
                src_val[self.layer],
                target_train[self.layer],
                self.num_workers,
                self.batch_size,
                self.temp_folder,
                self.framework_fn,
            )
            error_per_sample = self.error_fn(
                src_val[self.error_layer], src_val["labels"]
            )
            # get_dev_risk receives the errors as a column (N, 1).
            output = get_dev_risk(
                weights, error_per_sample[:, None], self.normalization
            )
            self.mean_error = torch.mean(error_per_sample).item()
        finally:
            # Previously the level was only restored on success, leaving the
            # shared logger silenced after any failure.
            c_f.LOGGER.setLevel(init_logging_level)
        # Negate so that a lower risk maps to a higher validation score.
        return -output

    def extra_repr(self):
        """Extend the base repr with the DEV recordable attributes."""
        x = super().extra_repr()
        x += f"\n{c_f.extra_repr(self, self._DEV_recordable)}"
        return x