Skip to content

combined_source_and_target

CombinedSourceAndTargetDataset

Bases: torch.utils.data.Dataset

Wraps a source dataset and a target dataset.

Source code in `pytorch_adapt/datasets/combined_source_and_target.py` (lines 11–59):
class CombinedSourceAndTargetDataset(torch.utils.data.Dataset):
    """
    Wraps a source dataset and a target dataset.

    The target dataset drives indexing and length. Each item pairs the
    requested target sample with a source sample chosen at random.
    """

    def __init__(self, source_dataset: SourceDataset, target_dataset: TargetDataset):
        """
        Arguments:
            source_dataset: Dataset from which one sample is drawn at random
                for every item.
            target_dataset: Dataset that is indexed directly and whose size
                is reported as the size of this wrapper.
        """
        self.source_dataset, self.target_dataset = source_dataset, target_dataset

    def __len__(self) -> int:
        """
        Returns:
            The number of samples in the target dataset.
        """
        num_target_samples = len(self.target_dataset)
        return num_target_samples

    def __getitem__(self, idx) -> Dict[str, Any]:
        """
        Arguments:
            idx: Index into the target dataset. The source sample is chosen
                independently of it, at random.
        Returns:
            A single dictionary holding both the source and the target data.
                Source keys are prefixed with `"src_"` and target keys with
                `"target_"`. See
                [`SourceDataset.__getitem__`][pytorch_adapt.datasets.SourceDataset.__getitem__]
                and
                [`TargetDataset.__getitem__`][pytorch_adapt.datasets.TargetDataset.__getitem__]
                for the individual entries.
        """
        # The target sample is read before the random source index is drawn.
        tgt_sample = self.target_dataset[idx]
        src_idx = self.get_random_src_idx()
        src_sample = self.source_dataset[src_idx]
        # NOTE(review): c_f.assert_dicts_are_disjoint is expected to check
        # that the two dicts share no keys and return their union — confirm.
        return c_f.assert_dicts_are_disjoint(src_sample, tgt_sample)

    def get_random_src_idx(self):
        """
        Returns:
            An index drawn uniformly from `range(len(self.source_dataset))`
            using numpy's global random state.
        """
        # NOTE(review): numpy's global RNG is not reseeded per DataLoader
        # worker by PyTorch. With num_workers > 1, confirm a worker_init_fn
        # seeds numpy, otherwise workers may draw identical source indices.
        num_src_samples = len(self.source_dataset)
        return np.random.choice(num_src_samples)

    def __repr__(self):
        sub_datasets = {
            "source_dataset": self.source_dataset,
            "target_dataset": self.target_dataset,
        }
        return c_f.nice_repr(self, "", sub_datasets)

__getitem__(idx)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `idx` | | The index of the target dataset. The source index is picked randomly. | required |

Returns:

| Type | Description |
|------|-------------|
| `Dict[str, Any]` | A dictionary containing both source and target data. The source keys start with `"src_"`, and the target keys start with `"target_"`. See `SourceDataset.__getitem__` and `TargetDataset.__getitem__` for details. |

Source code in `pytorch_adapt/datasets/combined_source_and_target.py` (lines 33–46):
def __getitem__(self, idx) -> Dict[str, Any]:
    """
    Arguments:
        idx: Index into the target dataset. The source sample is chosen
            independently of it, at random.
    Returns:
        A single dictionary holding both the source and the target data.
            Source keys are prefixed with `"src_"` and target keys with
            `"target_"`. See
            [`SourceDataset.__getitem__`][pytorch_adapt.datasets.SourceDataset.__getitem__]
            and
            [`TargetDataset.__getitem__`][pytorch_adapt.datasets.TargetDataset.__getitem__]
            for the individual entries.
    """
    # The target sample is read before the random source index is drawn.
    tgt_sample = self.target_dataset[idx]
    src_idx = self.get_random_src_idx()
    src_sample = self.source_dataset[src_idx]
    # NOTE(review): c_f.assert_dicts_are_disjoint is expected to check that
    # the two dicts share no keys and return their union — confirm.
    return c_f.assert_dicts_are_disjoint(src_sample, tgt_sample)

__init__(source_dataset, target_dataset)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `source_dataset` | `SourceDataset` | | required |
| `target_dataset` | `TargetDataset` | | required |
Source code in `pytorch_adapt/datasets/combined_source_and_target.py` (lines 16–24):
def __init__(self, source_dataset: SourceDataset, target_dataset: TargetDataset):
    """
    Arguments:
        source_dataset: Dataset from which one sample is drawn at random
            for every item.
        target_dataset: Dataset that is indexed directly and whose size is
            reported as the size of the wrapper.
    """
    self.source_dataset, self.target_dataset = source_dataset, target_dataset

__len__()

Returns:

| Type | Description |
|------|-------------|
| `int` | The length of the target dataset. |

Source code in `pytorch_adapt/datasets/combined_source_and_target.py` (lines 26–31):
def __len__(self) -> int:
    """
    Returns:
        The number of samples in the target dataset.
    """
    num_target_samples = len(self.target_dataset)
    return num_target_samples