Skip to content

base_dataset

BaseDataset

Bases: torch.utils.data.Dataset

Base dataset class

Source code in `pytorch_adapt/datasets/base_dataset.py` (lines 12–33):
class BaseDataset(torch.utils.data.Dataset):
    """
    Base dataset class.

    This class only stores ``domain``. Subclasses are expected to provide
    the ``img_paths``, ``labels`` and ``transform`` attributes that
    ``__len__``, ``__getitem__`` and ``__repr__`` read.
    """

    def __init__(self, domain: str):
        """
        Arguments:
            domain: Domain identifier; stored on the instance and shown in ``repr``.
        """
        super().__init__()
        self.domain = domain

    def __len__(self):
        """Return the number of entries in ``self.img_paths``."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """
        Return the ``(img, label)`` pair at position ``idx``.

        The image is read from ``self.img_paths[idx]`` and converted to RGB.
        It is then passed through ``self.transform`` unless that is None.
        """
        label = self.labels[idx]
        # Close the source file deterministically: Pillow opens images lazily
        # and can keep the handle open after loading for some formats (e.g.
        # multi-frame images). convert() returns a new image, so ``img`` stays
        # valid after ``src`` is closed.
        with Image.open(self.img_paths[idx]) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __repr__(self):
        extra_repr = f"domain={self.domain}\nlen={len(self)}"
        return c_f.nice_repr(self, extra_repr, {"transform": self.transform})

BaseDownloadableDataset

Bases: BaseDataset

Allows automatic downloading of datasets.

Source code in `pytorch_adapt/datasets/base_dataset.py` (lines 36–66):
class BaseDownloadableDataset(BaseDataset):
    """
    Allows automatic downloading of datasets.

    Subclasses implement ``set_paths_and_labels`` and define the ``url``,
    ``filename`` and ``md5`` attributes used by ``download_dataset``.
    """

    def __init__(self, root: str, download: bool = False, **kwargs):
        """
        Arguments:
            root: Folder where dataset will be downloaded to.
            download: If True, will download the dataset if it hasn't already been downloaded.
        """
        super().__init__(**kwargs)
        if not download:
            self.set_paths_and_labels(root)
            return
        try:
            self.set_paths_and_labels(root)
        except (FileNotFoundError, ValueError):
            # Treat a failed load as "not downloaded yet": fetch the archive,
            # then retry once. A second failure propagates to the caller.
            self.download_dataset(root)
            self.set_paths_and_labels(root)

    def set_paths_and_labels(self, root):
        """
        Load the dataset's entries from ``root``. Must be overridden.

        Raising ``FileNotFoundError`` or ``ValueError`` signals to ``__init__``
        (when ``download=True``) that the dataset should be downloaded.
        Presumably this sets the ``img_paths``/``labels`` attributes that
        ``BaseDataset`` reads; confirm against subclasses.
        """
        raise NotImplementedError

    def download_dataset(self, root):
        """
        Download ``self.url`` into ``root`` and extract the archive there.

        The archive is saved as ``self.filename``; ``self.md5`` is passed to
        ``download_url`` (presumably for an integrity check). Tar archives are
        opened with ``tarfile.open``; anything else is treated as a zip file.
        """
        download_url(self.url, root, filename=self.filename, md5=self.md5)
        filepath = os.path.join(root, self.filename)
        is_tar = tarfile.is_tarfile(filepath)
        decompressor = tarfile.open if is_tar else zipfile.ZipFile
        c_f.LOGGER.info("Extracting")
        extract_kwargs = {}
        # Python versions that support tar extraction filters (3.12+ and the
        # security backports) warn when no filter is given. The "data" filter
        # also rejects members that would escape ``root`` (absolute paths,
        # "..", unsafe links). zipfile.extractall has no such parameter.
        if is_tar and hasattr(tarfile, "data_filter"):
            extract_kwargs["filter"] = "data"
        with decompressor(filepath, "r") as f:
            f.extractall(path=root, members=c_f.extract_progress(f), **extract_kwargs)

`__init__(root, download=False, **kwargs)`

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `root` | `str` | Folder where the dataset will be downloaded to. | *required* |
| `download` | `bool` | If True, will download the dataset if it hasn't already been downloaded. | `False` |
Source code in `pytorch_adapt/datasets/base_dataset.py` (lines 41–55):
def __init__(self, root: str, download: bool = False, **kwargs):
    """
    Arguments:
        root: Folder where dataset will be downloaded to.
        download: If True, will download the dataset if it hasn't already been downloaded.
    """
    super().__init__(**kwargs)
    if not download:
        self.set_paths_and_labels(root)
        return
    try:
        self.set_paths_and_labels(root)
    except (FileNotFoundError, ValueError):
        # Treat a failed load as "not downloaded yet": fetch the archive,
        # then retry once. A second failure propagates to the caller.
        self.download_dataset(root)
        self.set_paths_and_labels(root)