Skip to content

target_dataset

TargetDataset

Bases: DomainDataset

Wrap your target dataset with this.

If supervised = True, the wrapped dataset's __getitem__ must return a tuple of (data, label). Otherwise it can return either (data, label) or data.

Source code in `pytorch_adapt/datasets/target_dataset.py` (lines 8–63)
class TargetDataset(DomainDataset):
    """
    Wrap your target dataset with this.

    If ```supervised = True```, the wrapped dataset's ```__getitem__```
    must return a tuple of ```(data, label)```.
    Otherwise it can return either ```(data, label)``` or ```data```.
    """

    def __init__(self, dataset: Dataset, domain: int = 1, supervised: bool = False):
        """
        Arguments:
            dataset: The dataset to wrap
            domain: An integer representing the domain.
            supervised: Whether ``__getitem__`` should include the
                ``"target_labels"`` key in its output.
        """
        super().__init__(dataset, domain)
        self.supervised = supervised

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Fetch sample ``idx`` from the wrapped dataset and package it as a dict.

        Returns:
            A dictionary with keys

                - "target_imgs" (the data)

                - "target_domain" (the integer representing the domain)

                - "target_sample_idx" (idx)

                If ```supervised = True``` it returns an extra key

                - "target_labels" (the class label)

        Raises:
            ValueError: if ``supervised`` is True but the wrapped dataset
                returned a bare sample rather than a list/tuple, or if a
                returned list/tuple does not have exactly two elements.
        """
        sample = self.dataset[idx]

        # Any list/tuple is interpreted as a (data, label) pair; in the
        # unsupervised case the label is simply not emitted.
        if isinstance(sample, (list, tuple)):
            sample, label = sample
        elif self.supervised:
            raise ValueError(
                "if TargetDataset is instantiated with supervised=True, the wrapped dataset must include labels"
            )

        output = {
            "target_imgs": sample,
            "target_domain": self.domain,
            "target_sample_idx": idx,
        }
        if self.supervised:
            output["target_labels"] = label
        return output

__getitem__(idx)

Returns:

Dict[str, Any]: A dictionary with keys
A dictionary with keys

  • "target_imgs" (the data)

  • "target_domain" (the integer representing the domain)

  • "target_sample_idx" (idx)

If supervised = True, it returns an extra key

  • "target_labels" (the class label)
Source code in `pytorch_adapt/datasets/target_dataset.py` (lines 27–63)
def __getitem__(self, idx: int) -> Dict[str, Any]:
    """
    Fetch sample ``idx`` from the wrapped dataset and package it as a dict.

    Returns:
        A dictionary with keys

            - "target_imgs" (the data)

            - "target_domain" (the integer representing the domain)

            - "target_sample_idx" (idx)

            If ```supervised = True``` it returns an extra key

            - "target_labels" (the class label)

    Raises:
        ValueError: if ``supervised`` is True but the wrapped dataset
            returned a bare sample rather than a list/tuple, or if a
            returned list/tuple does not have exactly two elements.
    """
    sample = self.dataset[idx]

    # Any list/tuple is interpreted as a (data, label) pair; in the
    # unsupervised case the label is simply not emitted.
    if isinstance(sample, (list, tuple)):
        sample, label = sample
    elif self.supervised:
        raise ValueError(
            "if TargetDataset is instantiated with supervised=True, the wrapped dataset must include labels"
        )

    output = {
        "target_imgs": sample,
        "target_domain": self.domain,
        "target_sample_idx": idx,
    }
    if self.supervised:
        output["target_labels"] = label
    return output

__init__(dataset, domain=1, supervised=False)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| dataset | Dataset | The dataset to wrap | required |
| domain | int | An integer representing the domain. | 1 |
| supervised | bool | Whether the target dataset should return labels. | False |
Source code in `pytorch_adapt/datasets/target_dataset.py` (lines 17–25)
def __init__(self, dataset: Dataset, domain: int = 1, supervised: bool = False):
    """
    Arguments:
        dataset: The dataset to wrap
        domain: An integer representing the domain.
        supervised: Whether ``__getitem__`` should include the
            ``"target_labels"`` key in its output.
    """
    # The base class receives the wrapped dataset and domain id; only the
    # supervision flag is stored here.
    super().__init__(dataset, domain)
    self.supervised = supervised