Reverse validation consists of three steps.
1. Train a model on the labeled source and unlabeled target.
2. Use the trained model to create pseudolabels for the target dataset.
3. Train a new model on the labeled target and "unlabeled" source.
The final score is the accuracy of the model from step 3.
Source code in pytorch_adapt\meta_validators\reverse_validator.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class ReverseValidator:
    """
    Reverse validation consists of three steps.

    1. Train a model on the labeled source and unlabeled target
    2. Use the trained model to create pseudolabels for the target dataset.
    3. Train a new model on the labeled target and "unlabeled" source.

    The final score is the accuracy of the model from step 3.
    """

    def __init__(self):
        # Pseudo-labeled datasets built by the most recent call to run().
        # Both remain None until run() has finished step 2.
        self.pseudo_train = None
        self.pseudo_val = None

    def run(
        self,
        forward_adapter,
        reverse_adapter,
        forward_kwargs,
        reverse_kwargs,
        pl_dataloader_creator=None,
    ) -> Tuple[float, int]:
        """
        Arguments:
            forward_adapter: the framework-wrapped adapter for step 1.
            reverse_adapter: the framework-wrapped adapter for step 3.
            forward_kwargs: a dict of keyword arguments to be passed to forward_adapter.run().
                Must contain the key ```"datasets"```.
            reverse_kwargs: a dict of keyword arguments to be passed to reverse_adapter.run().
                Must not contain ```"datasets"```. This dict is not modified.
            pl_dataloader_creator: An optional DataloaderCreator
                for obtaining pseudolabels in step 2.

        Returns:
            the best score and best epoch of the reverse model

        Raises:
            KeyError: if ```"datasets"``` is in ```reverse_kwargs```, if
                ```reverse_adapter.validator``` is falsy, or if ```"datasets"```
                is missing from ```forward_kwargs```.
        """
        if "datasets" in reverse_kwargs:
            raise KeyError(
                "'datasets' should not be in reverse_kwargs because the reverse datasets will be pseudo labeled."
            )
        if not reverse_adapter.validator:
            raise KeyError("reverse_adapter must include 'validator'")
        # Look the datasets up before training so that a missing key is
        # reported immediately instead of after the forward run.
        datasets = forward_kwargs["datasets"]

        # Step 1: train the forward model.
        forward_adapter.run(**forward_kwargs)
        # Only when the forward adapter has both a validator and a checkpoint_fn
        # is the best checkpoint loaded into the forward models before they are
        # used for pseudo labeling.
        if all(getattr(forward_adapter, x) for x in ["validator", "checkpoint_fn"]):
            forward_adapter.checkpoint_fn.load_best_checkpoint(
                {"models": forward_adapter.adapter.models},
            )

        # Step 2: pseudo label the target splits.
        # NOTE(review): c_f.default presumably builds DataloaderCreator(all_val=True)
        # when pl_dataloader_creator is None -- confirm against c_f.default.
        pl_dataloader_creator = c_f.default(
            pl_dataloader_creator, DataloaderCreator, {"all_val": True}
        )
        # The domains are swapped for the reverse run: the pseudo-labeled target
        # data is stored under the "src_*" keys and the original source data is
        # wrapped in TargetDataset under the "target_*" keys.
        d = {}
        d["src_train"] = get_pseudo_labeled_dataset(
            forward_adapter, datasets, "target_train", pl_dataloader_creator
        )
        d["src_val"] = get_pseudo_labeled_dataset(
            forward_adapter, datasets, "target_val", pl_dataloader_creator
        )
        d["target_train"] = TargetDataset(datasets["src_train"].dataset)
        d["target_val"] = TargetDataset(datasets["src_val"].dataset)
        d["train"] = CombinedSourceAndTargetDataset(d["src_train"], d["target_train"])
        self.pseudo_train = d["src_train"]
        self.pseudo_val = d["src_val"]

        # Step 3: train the reverse model. The datasets are passed as a separate
        # keyword instead of being written into reverse_kwargs, so the caller's
        # dict is left untouched and can be reused for another run.
        return reverse_adapter.run(**reverse_kwargs, datasets=d)
|
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| forward_adapter | | the framework-wrapped adapter for step 1. | required |
| reverse_adapter | | the framework-wrapped adapter for step 3. | required |
| forward_kwargs | | a dict of keyword arguments to be passed to forward_adapter.run() | required |
| reverse_kwargs | | a dict of keyword arguments to be passed to reverse_adapter.run() | required |
| pl_dataloader_creator | | An optional DataloaderCreator for obtaining pseudolabels in step 2. | None |
Returns:

| Type | Description |
| --- | --- |
| Tuple[float, int] | the best score and best epoch of the reverse model |
Source code in pytorch_adapt\meta_validators\reverse_validator.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def run(
    self,
    forward_adapter,
    reverse_adapter,
    forward_kwargs,
    reverse_kwargs,
    pl_dataloader_creator=None,
) -> Tuple[float, int]:
    """
    Arguments:
        forward_adapter: the framework-wrapped adapter for step 1.
        reverse_adapter: the framework-wrapped adapter for step 3.
        forward_kwargs: a dict of keyword arguments to be passed to forward_adapter.run().
            Must contain the key ```"datasets"```.
        reverse_kwargs: a dict of keyword arguments to be passed to reverse_adapter.run().
            Must not contain ```"datasets"```. This dict is not modified.
        pl_dataloader_creator: An optional DataloaderCreator
            for obtaining pseudolabels in step 2.

    Returns:
        the best score and best epoch of the reverse model

    Raises:
        KeyError: if ```"datasets"``` is in ```reverse_kwargs```, if
            ```reverse_adapter.validator``` is falsy, or if ```"datasets"```
            is missing from ```forward_kwargs```.
    """
    if "datasets" in reverse_kwargs:
        raise KeyError(
            "'datasets' should not be in reverse_kwargs because the reverse datasets will be pseudo labeled."
        )
    if not reverse_adapter.validator:
        raise KeyError("reverse_adapter must include 'validator'")
    # Look the datasets up before training so that a missing key is
    # reported immediately instead of after the forward run.
    datasets = forward_kwargs["datasets"]

    # Step 1: train the forward model.
    forward_adapter.run(**forward_kwargs)
    # Only when the forward adapter has both a validator and a checkpoint_fn
    # is the best checkpoint loaded into the forward models before they are
    # used for pseudo labeling.
    if all(getattr(forward_adapter, x) for x in ["validator", "checkpoint_fn"]):
        forward_adapter.checkpoint_fn.load_best_checkpoint(
            {"models": forward_adapter.adapter.models},
        )

    # Step 2: pseudo label the target splits.
    # NOTE(review): c_f.default presumably builds DataloaderCreator(all_val=True)
    # when pl_dataloader_creator is None -- confirm against c_f.default.
    pl_dataloader_creator = c_f.default(
        pl_dataloader_creator, DataloaderCreator, {"all_val": True}
    )
    # The domains are swapped for the reverse run: the pseudo-labeled target
    # data is stored under the "src_*" keys and the original source data is
    # wrapped in TargetDataset under the "target_*" keys.
    d = {}
    d["src_train"] = get_pseudo_labeled_dataset(
        forward_adapter, datasets, "target_train", pl_dataloader_creator
    )
    d["src_val"] = get_pseudo_labeled_dataset(
        forward_adapter, datasets, "target_val", pl_dataloader_creator
    )
    d["target_train"] = TargetDataset(datasets["src_train"].dataset)
    d["target_val"] = TargetDataset(datasets["src_val"].dataset)
    d["train"] = CombinedSourceAndTargetDataset(d["src_train"], d["target_train"])
    self.pseudo_train = d["src_train"]
    self.pseudo_val = d["src_val"]

    # Step 3: train the reverse model. The datasets are passed as a separate
    # keyword instead of being written into reverse_kwargs, so the caller's
    # dict is left untouched and can be reused for another run.
    return reverse_adapter.run(**reverse_kwargs, datasets=d)
|