Bases: ABC
The parent class of all validators.
The main purpose of validators is to give an estimate
of target domain accuracy, usually without having access to
class labels.
Source code in pytorch_adapt\validators\base_validator.py
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class BaseValidator(ABC):
    """
    The parent class of all validators.
    The main purpose of validators is to give an estimate
    of target domain accuracy, usually without having access to
    class labels.

    Subclasses implement ``compute_score``. The names of its positional
    parameters (excluding the bound ``self``) are the dataset splits the
    validator consumes, before any ``key_map`` renaming is applied.
    """

    def __init__(self, key_map: Dict[str, str] = None):
        """
        Arguments:
            key_map: A mapping from ```<new_split_names>``` to
                ```<original_split_names>```. For example,
                [```AccuracyValidator```][pytorch_adapt.validators.AccuracyValidator]
                expects ```src_val``` by default. When used with one of the
                [```frameworks```](../frameworks/index.md), this default
                indicates that data related to the ```src_val``` split should be retrieved.
                If you instead want to compute accuracy for the ```src_train``` split,
                you would set the ```key_map``` to ```{"src_train": "src_val"}```.
        """
        self.key_map = c_f.default(key_map, {})

    def _required_data(self) -> List[str]:
        """
        Returns:
            The positional parameter names of ``compute_score`` in
            declaration order, without the already-bound first argument.
        """
        method = self.compute_score
        args = inspect.getfullargspec(method).args
        # getfullargspec still reports the bound first parameter of a bound
        # method, so drop it regardless of its name. A staticmethod is not
        # bound, so all of its parameters are real split names.
        if inspect.ismethod(method):
            args = args[1:]
        return args

    @property
    def required_data(self) -> List[str]:
        """
        Returns:
            A list of dataset split names: the ``compute_score`` arguments
            that are not targets of ``key_map``, followed by the ``key_map``
            keys that replace them. The order is deterministic.
        """
        renamed = set(self.key_map.values())
        # Preserve compute_score's declaration order rather than going through
        # a set, whose iteration order for str varies between runs.
        output = [name for name in self._required_data() if name not in renamed]
        # Names supplied via key_map stand in for the ones they map onto;
        # skip any already present so no split is listed twice.
        extra = [name for name in self.key_map if name not in output]
        return output + extra

    @abstractmethod
    def compute_score(self):
        """
        Compute the validation score. Subclasses declare one positional
        parameter per dataset split they need.
        """

    def __call__(self, **kwargs) -> float:
        """
        Arguments:
            **kwargs: A mapping from dataset split name to
                dictionaries containing:

                - ```"features"```
                - ```"logits"```
                - ```"preds"```
                - ```"domain"```
                - ```"labels"``` (if available)
        Returns:
            The validation score.
        """
        kwargs = self.kwargs_check(kwargs)
        return self.compute_score(**kwargs)

    def kwargs_check(self, kwargs):
        """
        Check that exactly the splits in ``required_data`` were supplied,
        then rename them according to ``key_map`` via ``c_f.map_keys``.

        Raises:
            ValueError: if the keys of ``kwargs`` differ from ``required_data``.
        """
        required = self.required_data  # computed once; reused in the message
        if kwargs.keys() != set(required):
            raise ValueError(
                f"Input to compute_score has keys = {kwargs.keys()} but should have keys {required}"
            )
        return c_f.map_keys(kwargs, self.key_map)

    def __repr__(self):
        # Delegates formatting to the project helper c_f.nice_repr.
        return c_f.nice_repr(self, self.extra_repr(), {})

    def extra_repr(self):
        # Includes the required_data property in the repr.
        return c_f.extra_repr(self, ["required_data"])
|
`__call__(**kwargs)`

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**kwargs` | | A mapping from dataset split name to dictionaries containing: `"features"`, `"logits"`, `"preds"`, `"domain"`, and `"labels"` (if available). | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `float` | The validation score. |
Source code in pytorch_adapt\validators\base_validator.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def __call__(self, **kwargs) -> float:
    """
    Validate the incoming splits, rename them via ``key_map``, and score them.

    Arguments:
        **kwargs: A mapping from dataset split name to
            dictionaries containing:

            - ```"features"```
            - ```"logits"```
            - ```"preds"```
            - ```"domain"```
            - ```"labels"``` (if available)
    Returns:
        The validation score.
    """
    remapped = self.kwargs_check(kwargs)
    score = self.compute_score(**remapped)
    return score
|
`__init__(key_map=None)`

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `key_map` | `Dict[str, str]` | A mapping from `<new_split_names>` to `<original_split_names>`. For example, `AccuracyValidator` expects `src_val` by default. When used with one of the frameworks, this default indicates that data related to the `src_val` split should be retrieved. If you instead want to compute accuracy for the `src_train` split, you would set the `key_map` to `{"src_train": "src_val"}`. | `None` |
Source code in pytorch_adapt\validators\base_validator.py
17
18
19
20
21
22
23
24
25
26
27
28
def __init__(self, key_map: Dict[str, str] = None):
    """
    Store the split-name remapping used by this validator.

    Arguments:
        key_map: A mapping from ```<new_split_names>``` to
            ```<original_split_names>```. For example,
            [```AccuracyValidator```][pytorch_adapt.validators.AccuracyValidator]
            expects ```src_val``` by default. When used with one of the
            [```frameworks```](../frameworks/index.md), this default
            indicates that data related to the ```src_val``` split should be retrieved.
            If you instead want to compute accuracy for the ```src_train``` split,
            you would set the ```key_map``` to ```{"src_train": "src_val"}```.
    """
    # NOTE(review): c_f.default presumably returns {} when key_map is None
    # and key_map otherwise — confirm against pytorch_adapt's common_functions.
    self.key_map = c_f.default(key_map, {})
|
`required_data` (property)

Returns:

| Type | Description |
| --- | --- |
| `List[str]` | A list of dataset split names. |
Source code in pytorch_adapt\validators\base_validator.py
36
37
38
39
40
41
42
43
44
45
@property
def required_data(self) -> List[str]:
    """
    Returns:
        A list of dataset split names: the ``compute_score`` arguments that
        are not targets of ``key_map``, followed by the ``key_map`` keys that
        replace them. The order is deterministic.
    """
    renamed = set(self.key_map.values())
    # Keep declaration order instead of round-tripping through a set, whose
    # iteration order for str changes between runs (hash randomization).
    output = [name for name in self._required_data() if name not in renamed]
    # key_map keys stand in for the names they map onto; avoid listing a
    # split twice if a key is also an unmapped argument.
    extra = [name for name in self.key_map if name not in output]
    return output + extra
|