
base_validator

BaseValidator

Bases: ABC

The parent class of all validators.

The main purpose of validators is to give an estimate of target domain accuracy, usually without having access to class labels.

Source code in pytorch_adapt\validators\base_validator.py
```python
# Module-level imports this class depends on (above the excerpt in the file).
import inspect
from abc import ABC, abstractmethod
from typing import Dict, List

from ..utils import common_functions as c_f


class BaseValidator(ABC):
    """
    The parent class of all validators.

    The main purpose of validators is to give an estimate
    of target domain accuracy, usually without having access to
    class labels.
    """

    def __init__(self, key_map: Dict[str, str] = None):
        """
        Arguments:
            key_map: A mapping from ```<new_split_names>``` to
                ```<original_split_names>```. For example,
                [```AccuracyValidator```][pytorch_adapt.validators.AccuracyValidator]
                expects ```src_val``` by default. When used with one of the
                [```frameworks```](../frameworks/index.md), this default
                indicates that data related to the ```src_val``` split should be retrieved.
                If you instead want to compute accuracy for the ```src_train``` split,
                you would set the ```key_map``` to ```{"src_train": "src_val"}```.
        """
        self.key_map = c_f.default(key_map, {})

    def _required_data(self):
        # The split names are the parameter names of compute_score, minus "self".
        args = inspect.getfullargspec(self.compute_score).args
        args.remove("self")
        return args

    @property
    def required_data(self) -> List[str]:
        """
        Returns:
            A list of dataset split names.
        """
        # Drop the original names that key_map renames, then add the new names.
        output = set(self._required_data()) - set(self.key_map.values())
        output = list(output)
        for k, v in self.key_map.items():
            output.append(k)
        return output

    @abstractmethod
    def compute_score(self):
        pass

    def __call__(self, **kwargs) -> float:
        """
        Arguments:
            **kwargs: A mapping from dataset split name to
                dictionaries containing:

                - ```"features"```
                - ```"logits"```
                - ```"preds"```
                - ```"domain"```
                - ```"labels"``` (if available)
        Returns:
            The validation score.
        """
        kwargs = self.kwargs_check(kwargs)
        return self.compute_score(**kwargs)

    def kwargs_check(self, kwargs):
        # The caller must supply exactly the splits listed in required_data.
        if kwargs.keys() != set(self.required_data):
            raise ValueError(
                f"Input to compute_score has keys = {kwargs.keys()} but should have keys {self.required_data}"
            )
        # Rename the new split names back to compute_score's parameter names.
        return c_f.map_keys(kwargs, self.key_map)

    def __repr__(self):
        return c_f.nice_repr(self, self.extra_repr(), {})

    def extra_repr(self):
        return c_f.extra_repr(self, ["required_data"])
```
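The class relies on one convention: a subclass names the splits it needs as the parameters of `compute_score`, and `required_data` is read from that signature. The sketch below mirrors the mechanism in plain Python so it runs without pytorch_adapt installed. `MiniBaseValidator` and `SrcValAccuracy` are hypothetical names; the inline dictionary operations stand in for `c_f.default` and `c_f.map_keys`, whose behavior is assumed rather than taken from the library; predictions and labels are plain lists instead of tensors; and this `required_data` keeps parameter order instead of passing through a set.

```python
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class MiniBaseValidator(ABC):
    """Hypothetical, dependency-free stand-in for BaseValidator."""

    def __init__(self, key_map: Optional[Dict[str, str]] = None):
        # Assumed equivalent of c_f.default(key_map, {}).
        self.key_map = key_map if key_map is not None else {}

    @property
    def required_data(self) -> List[str]:
        # Parameter names of compute_score are the split names it consumes.
        args = inspect.getfullargspec(self.compute_score).args
        args.remove("self")
        # Original names that key_map renames are replaced by their new names.
        renamed = set(self.key_map.values())
        return [a for a in args if a not in renamed] + list(self.key_map)

    @abstractmethod
    def compute_score(self) -> float:
        ...

    def __call__(self, **kwargs: Dict[str, Any]) -> float:
        # Exactly the required splits must be supplied, no more and no fewer.
        if kwargs.keys() != set(self.required_data):
            raise ValueError(
                f"got splits {sorted(kwargs)}, expected {sorted(self.required_data)}"
            )
        # Assumed equivalent of c_f.map_keys: new names -> compute_score's names.
        mapped = {self.key_map.get(k, k): v for k, v in kwargs.items()}
        return self.compute_score(**mapped)


class SrcValAccuracy(MiniBaseValidator):
    """Hypothetical validator: accuracy of "preds" against "labels" on src_val."""

    def compute_score(self, src_val: Dict[str, List[int]]) -> float:
        preds, labels = src_val["preds"], src_val["labels"]
        correct = sum(int(p == y) for p, y in zip(preds, labels))
        return correct / len(labels)


v = SrcValAccuracy()
print(v.required_data)
print(v(src_val={"preds": [0, 1, 1, 0], "labels": [0, 1, 0, 0]}))

# Score the src_train split instead, as described for key_map.
v2 = SrcValAccuracy(key_map={"src_train": "src_val"})
print(v2.required_data)
```

`SrcValAccuracy` is not pytorch_adapt's `AccuracyValidator`; it only shows where a real subclass would put its scoring logic, while the base class handles split checking and renaming.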

__call__(**kwargs)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `**kwargs` | | A mapping from dataset split name to dictionaries containing `"features"`, `"logits"`, `"preds"`, `"domain"`, and `"labels"` (if available). | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `float` | The validation score. |

Source code in pytorch_adapt\validators\base_validator.py
```python
def __call__(self, **kwargs) -> float:
    """
    Arguments:
        **kwargs: A mapping from dataset split name to
            dictionaries containing:

            - ```"features"```
            - ```"logits"```
            - ```"preds"```
            - ```"domain"```
            - ```"labels"``` (if available)
    Returns:
        The validation score.
    """
    kwargs = self.kwargs_check(kwargs)
    return self.compute_score(**kwargs)
```
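Because `kwargs_check` compares key sets with `!=`, `__call__` rejects a missing split and an extra split alike. The standalone function below reproduces only that comparison. `check_split_names` is a hypothetical helper, and the per-split dictionaries hold plain lists where pytorch_adapt would pass tensors.

```python
from typing import Any, Dict, List


def check_split_names(kwargs: Dict[str, Any], required: List[str]) -> None:
    """Hypothetical mirror of the key check in BaseValidator.kwargs_check."""
    # dict_keys supports set comparison, so the order of names does not matter.
    if kwargs.keys() != set(required):
        raise ValueError(
            f"Input has keys = {sorted(kwargs)} but should have keys {sorted(required)}"
        )


# Each split name maps to a dict of model outputs (placeholder values).
src_val = {"preds": [0, 1], "labels": [0, 1]}
target_train = {"preds": [1, 1]}

check_split_names({"src_val": src_val}, ["src_val"])  # passes silently

# A missing split and an extra split are both rejected.
for bad in ({}, {"src_val": src_val, "target_train": target_train}):
    try:
        check_split_names(bad, ["src_val"])
    except ValueError as e:
        print("rejected:", e)
```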

__init__(key_map=None)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `key_map` | `Dict[str, str]` | A mapping from `<new_split_names>` to `<original_split_names>`. For example, `AccuracyValidator` expects `src_val` by default. When used with one of the `frameworks`, this default indicates that data related to the `src_val` split should be retrieved. If you instead want to compute accuracy for the `src_train` split, you would set `key_map` to `{"src_train": "src_val"}`. | `None` |
Source code in pytorch_adapt\validators\base_validator.py
```python
def __init__(self, key_map: Dict[str, str] = None):
    """
    Arguments:
        key_map: A mapping from ```<new_split_names>``` to
            ```<original_split_names>```. For example,
            [```AccuracyValidator```][pytorch_adapt.validators.AccuracyValidator]
            expects ```src_val``` by default. When used with one of the
            [```frameworks```](../frameworks/index.md), this default
            indicates that data related to the ```src_val``` split should be retrieved.
            If you instead want to compute accuracy for the ```src_train``` split,
            you would set the ```key_map``` to ```{"src_train": "src_val"}```.
    """
    self.key_map = c_f.default(key_map, {})
```
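With `key_map={"src_train": "src_val"}`, the caller passes `src_train=...` and `compute_score` receives the same data as `src_val`. Assuming `c_f.map_keys` renames keys found in the mapping and leaves all other keys unchanged, that renaming step behaves like the hypothetical `rename_splits` below.

```python
from typing import Any, Dict


def rename_splits(kwargs: Dict[str, Any], key_map: Dict[str, str]) -> Dict[str, Any]:
    """Map <new_split_names> back to <original_split_names>; unmapped keys pass through."""
    return {key_map.get(k, k): v for k, v in kwargs.items()}


key_map = {"src_train": "src_val"}
# The caller supplies data for the src_train split...
inputs = {"src_train": {"preds": [1, 0], "labels": [1, 1]}}
# ...and compute_score receives it under its own parameter name, src_val.
print(rename_splits(inputs, key_map))
```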

required_data property

Returns:

| Type | Description |
|------|-------------|
| `List[str]` | A list of dataset split names: the keyword arguments that must be passed to `__call__`. |

Source code in pytorch_adapt\validators\base_validator.py
```python
@property
def required_data(self) -> List[str]:
    """
    Returns:
        A list of dataset split names.
    """
    output = set(self._required_data()) - set(self.key_map.values())
    output = list(output)
    for k, v in self.key_map.items():
        output.append(k)
    return output
```
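`required_data` starts from the parameter names of `compute_score`, removes any original names that `key_map` renames, and appends the new names. Since the intermediate result passes through a `set`, the order of the returned list should not be relied on, so the sketch sorts before printing. `required_splits` and `TwoSplitValidator` are hypothetical names used only for illustration.

```python
import inspect
from typing import Callable, Dict, List


def required_splits(compute_score: Callable, key_map: Dict[str, str]) -> List[str]:
    """Hypothetical mirror of BaseValidator.required_data for a bound compute_score."""
    args = inspect.getfullargspec(compute_score).args
    args.remove("self")  # getfullargspec keeps "self" even for bound methods
    # Same steps as the property: drop renamed originals, then append the new names.
    output = list(set(args) - set(key_map.values()))
    output.extend(key_map.keys())
    return output


class TwoSplitValidator:
    """Hypothetical validator that needs source-train and target-train data."""

    def compute_score(self, src_train, target_train):
        return 0.0


v = TwoSplitValidator()
print(sorted(required_splits(v.compute_score, {})))
print(sorted(required_splits(v.compute_score, {"target_val": "target_train"})))
```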