Bases: DomainDataset
Wrap your target dataset with this.
If supervised = True, the wrapped dataset's __getitem__
must return a tuple of (data, label).
Otherwise it can return either (data, label) or data.
Source code in pytorch_adapt\datasets\target_dataset.py
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class TargetDataset(DomainDataset):
    """
    Wrap your target dataset with this.

    Each item is returned as a dictionary whose keys are prefixed with
    ```target_```. If ```supervised = True```, the wrapped dataset's
    ```__getitem__``` must return a tuple of ```(data, label)```.
    Otherwise it can return either ```(data, label)``` or ```data```.
    """

    def __init__(
        self,
        dataset: Dataset,
        domain: int = 1,
        supervised: bool = False,
    ):
        """
        Arguments:
            dataset: The dataset to wrap.
            domain: An integer representing the domain.
            supervised: Whether ``__getitem__`` should add a
                ``"target_labels"`` key to its output. When True, the
                wrapped dataset is required to provide labels.
        """
        super().__init__(dataset, domain)
        self.supervised = supervised

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Fetch sample ``idx`` from the wrapped dataset and package it as a dict.

        Returns:
            A dictionary with keys
                - "target_imgs" (the data)
                - "target_domain" (the integer representing the domain)
                - "target_sample_idx" (idx)
            If ```supervised = True``` it returns an extra key
                - "target_labels" (the class label)

        Raises:
            ValueError: if ``supervised`` is True but the wrapped dataset
                returned a sample that is not a list/tuple. A list/tuple
                sample whose length is not exactly 2 also raises
                ValueError, from the pair unpacking.
        """
        sample = self.dataset[idx]
        # NOTE: any list/tuple is interpreted as a (data, label) pair, so a
        # sample whose data is itself a list/tuple would be split here.
        labelled = isinstance(sample, (list, tuple))
        if labelled:
            sample, label = sample
        if self.supervised and not labelled:
            raise ValueError(
                "if TargetDataset is instantiated with supervised=True, the wrapped dataset must include labels"
            )
        output = {
            "target_imgs": sample,
            "target_domain": self.domain,
            "target_sample_idx": idx,
        }
        if self.supervised:
            # `label` is guaranteed bound here: the check above raised otherwise.
            output["target_labels"] = label
        return output
|
__getitem__(idx)
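Returns:
| Type | Description |
|------|-------------|
| Dict[str, Any] | A dictionary with keys "target_imgs" (the data), "target_domain" (the integer representing the domain) and "target_sample_idx" (idx). If supervised = True it returns an extra key "target_labels" (the class label). |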
Source code in pytorch_adapt\datasets\target_dataset.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63 | def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Returns:
A dictionary with keys
- "target_imgs" (the data)
- "target_domain" (the integer representing the domain)
- "target_sample_idx" (idx)
If ```supervised = True``` it returns an extra key
- "target_labels" (the class label)
"""
has_labels = False
img = self.dataset[idx]
if isinstance(img, (list, tuple)):
has_labels = True
img, labels = img
if self.supervised and not has_labels:
raise ValueError(
"if TargetDataset is instantiated with supervised=True, the wrapped dataset must include labels"
)
item = {
"target_imgs": img,
"target_domain": self.domain,
"target_sample_idx": idx,
}
if self.supervised:
item["target_labels"] = labels
return item
|
__init__(dataset, domain=1, supervised=False)
Parameters:
| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataset | Dataset | The dataset to wrap. | required |
| domain | int | An integer representing the domain. | 1 |
| supervised | bool | Whether the target dataset should return labels. | False |
Source code in pytorch_adapt\datasets\target_dataset.py
17
18
19
20
21
22
23
24
def __init__(
    self,
    dataset: Dataset,
    domain: int = 1,
    supervised: bool = False,
):
    """
    Arguments:
        dataset: The dataset to wrap.
        domain: An integer representing the domain.
        supervised: Whether ``__getitem__`` should add a
            ``"target_labels"`` key to its output. When True, the wrapped
            dataset is required to provide labels.
    """
    super().__init__(dataset, domain)
    self.supervised = supervised
|