Skip to content

reverse_validator

ReverseValidator

Reverse validation consists of three steps.

  1. Train a model on the labeled source and unlabeled target.

  2. Use the trained model to create pseudolabels for the target dataset.

  3. Train a new model on the labeled target and "unlabeled" source.

The final score is the accuracy of the model from step 3.

Source code in pytorch_adapt\meta_validators\reverse_validator.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class ReverseValidator:
    """
    Scores an adaptation setup by reverse validation, a three-stage procedure:

    1. Fit a model using the labeled source and the unlabeled target.

    2. Have that fitted model produce pseudolabels for the target dataset.

    3. Fit a second model with the roles swapped: the pseudolabeled target
       is the labeled data and the source is treated as "unlabeled".

    The accuracy of the stage-3 model is the final score.
    """

    def __init__(self):
        # Populated by run() with the pseudolabeled "target_train" and
        # "target_val" datasets respectively.
        self.pseudo_train = None
        self.pseudo_val = None

    def run(
        self,
        forward_adapter,
        reverse_adapter,
        forward_kwargs,
        reverse_kwargs,
        pl_dataloader_creator=None,
    ) -> Tuple[float, int]:
        """
        Arguments:
            forward_adapter: the framework-wrapped adapter for step 1.
            reverse_adapter: the framework-wrapped adapter for step 3.
                Its ```validator``` attribute must be truthy.
            forward_kwargs: a dict of keyword arguments to be passed to
                forward_adapter.run(). It must contain a ```"datasets"``` entry.
            reverse_kwargs: a dict of keyword arguments to be passed to
                reverse_adapter.run(). It must not contain ```"datasets"```;
                this method stores the reverse datasets under that key
                in the dict that was passed in.
            pl_dataloader_creator: An optional DataloaderCreator
                for obtaining pseudolabels in step 2.
        Returns:
            the best score and best epoch of the reverse model
        Raises:
            KeyError: if ```"datasets"``` is in reverse_kwargs, or if
                reverse_adapter has no validator.
        """
        # Validate the reverse configuration before any training happens.
        if "datasets" in reverse_kwargs:
            raise KeyError(
                "'datasets' should not be in reverse_kwargs because the reverse datasets will be pseudo labeled."
            )
        if not reverse_adapter.validator:
            raise KeyError("reverse_adapter must include 'validator'")

        # Step 1: train the forward model, then restore its best checkpoint
        # when both a validator and a checkpoint function are available.
        forward_adapter.run(**forward_kwargs)
        if forward_adapter.validator and forward_adapter.checkpoint_fn:
            forward_adapter.checkpoint_fn.load_best_checkpoint(
                {"models": forward_adapter.adapter.models},
            )

        forward_datasets = forward_kwargs["datasets"]
        creator = c_f.default(
            pl_dataloader_creator, DataloaderCreator, {"all_val": True}
        )

        # Step 2: pseudolabel the target splits with the forward model.
        labeled_train, labeled_val = [
            get_pseudo_labeled_dataset(
                forward_adapter, forward_datasets, split_name, creator
            )
            for split_name in ("target_train", "target_val")
        ]

        # Step 3 setup: the original source splits take the target role.
        unlabeled_train = TargetDataset(forward_datasets["src_train"].dataset)
        unlabeled_val = TargetDataset(forward_datasets["src_val"].dataset)
        reverse_datasets = {
            "src_train": labeled_train,
            "src_val": labeled_val,
            "target_train": unlabeled_train,
            "target_val": unlabeled_val,
            "train": CombinedSourceAndTargetDataset(labeled_train, unlabeled_train),
        }

        # Keep the pseudolabeled splits on the instance.
        self.pseudo_train = labeled_train
        self.pseudo_val = labeled_val

        reverse_kwargs["datasets"] = reverse_datasets
        return reverse_adapter.run(**reverse_kwargs)

run(forward_adapter, reverse_adapter, forward_kwargs, reverse_kwargs, pl_dataloader_creator=None)

Parameters:

Name Type Description Default
forward_adapter

the framework-wrapped adapter for step 1.

required
reverse_adapter

the framework-wrapped adapter for step 3.

required
forward_kwargs

a dict of keyword arguments to be passed to forward_adapter.run()

required
reverse_kwargs

a dict of keyword arguments to be passed to reverse_adapter.run()

required
pl_dataloader_creator

An optional DataloaderCreator for obtaining pseudolabels in step 2.

None

Returns:

Type Description
Tuple[float, int]

the best score and best epoch of the reverse model

Source code in pytorch_adapt\meta_validators\reverse_validator.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def run(
    self,
    forward_adapter,
    reverse_adapter,
    forward_kwargs,
    reverse_kwargs,
    pl_dataloader_creator=None,
) -> Tuple[float, int]:
    """
    Arguments:
        forward_adapter: the framework-wrapped adapter for step 1.
        reverse_adapter: the framework-wrapped adapter for step 3.
        forward_kwargs: a dict of keyword arguments to be passed to forward_adapter.run()
        reverse_kwargs: a dict of keyword arguments to be passed to reverse_adapter.run()
        pl_dataloader_creator: An optional DataloaderCreator
            for obtaining pseudolabels in step 2.
    Returns:
        the best score and best epoch of the reverse model
    """

    if "datasets" in reverse_kwargs:
        raise KeyError(
            "'datasets' should not be in reverse_kwargs because the reverse datasets will be pseudo labeled."
        )
    if not reverse_adapter.validator:
        raise KeyError("reverse_adapter must include 'validator'")

    forward_adapter.run(**forward_kwargs)
    if all(getattr(forward_adapter, x) for x in ["validator", "checkpoint_fn"]):
        forward_adapter.checkpoint_fn.load_best_checkpoint(
            {"models": forward_adapter.adapter.models},
        )

    datasets = forward_kwargs["datasets"]
    pl_dataloader_creator = c_f.default(
        pl_dataloader_creator, DataloaderCreator, {"all_val": True}
    )

    d = {}
    d["src_train"] = get_pseudo_labeled_dataset(
        forward_adapter, datasets, "target_train", pl_dataloader_creator
    )
    d["src_val"] = get_pseudo_labeled_dataset(
        forward_adapter, datasets, "target_val", pl_dataloader_creator
    )
    d["target_train"] = TargetDataset(datasets["src_train"].dataset)
    d["target_val"] = TargetDataset(datasets["src_val"].dataset)
    d["train"] = CombinedSourceAndTargetDataset(d["src_train"], d["target_train"])

    self.pseudo_train = d["src_train"]
    self.pseudo_val = d["src_val"]

    reverse_kwargs["datasets"] = d
    return reverse_adapter.run(**reverse_kwargs)