Skip to content

mnistm

MNISTM

Bases: BaseDownloadableDataset

The dataset used in "Domain-Adversarial Training of Neural Networks". It consists of colored MNIST digits.

Extends BaseDownloadableDataset, so the dataset can be downloaded by setting download=True when initializing.

Source code in pytorch_adapt\datasets\mnistm.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class MNISTM(BaseDownloadableDataset):
    """
    The dataset used in "Domain-Adversarial Training of Neural Networks".
    It consists of colored MNIST digits.

    Extends [BaseDownloadableDataset][pytorch_adapt.datasets.BaseDownloadableDataset],
    so the dataset can be downloaded by setting ```download=True``` when
    initializing.
    """

    url = "https://cornell.box.com/shared/static/jado7quprg6hzzdubvwzh9umr75damwi"
    filename = "mnist_m.tar.gz"
    md5 = "859df31c91afe82e80e5012ba928f279"

    def __init__(self, root: str, train: bool, transform=None, **kwargs):
        """
        Arguments:
            root: The dataset must be located at ```<root>/mnist_m```
            train: Whether or not to use the training set.
            transform: The image transform applied to each sample.
        """
        self.train = check_train(train)
        super().__init__(root=root, domain="MNISTM", **kwargs)
        self.transform = transform

    def set_paths_and_labels(self, root):
        name = "train" if self.train else "test"
        labels_file = os.path.join(root, "mnist_m", f"mnist_m_{name}_labels.txt")
        img_dir = os.path.join(root, "mnist_m", f"mnist_m_{name}")
        with open(labels_file) as f:
            content = [line.rstrip().split(" ") for line in f]
        self.img_paths = [os.path.join(img_dir, x[0]) for x in content]
        check_length(self, {"train": 59001, "test": 9001}[name])
        self.labels = [int(x[1]) for x in content]

__init__(root, train, transform=None, **kwargs)

Parameters:

Name Type Description Default
root str

The dataset must be located at <root>/mnist_m

required
train bool

Whether or not to use the training set.

required
transform

The image transform applied to each sample.

None
Source code in pytorch_adapt\datasets\mnistm.py
21
22
23
24
25
26
27
28
29
30
def __init__(self, root: str, train: bool, transform=None, **kwargs):
    """
    Arguments:
        root: The dataset must be located at ```<root>/mnist_m```
        train: Whether or not to use the training set.
        transform: The image transform applied to each sample.
    """
    self.train = check_train(train)
    super().__init__(root=root, domain="MNISTM", **kwargs)
    self.transform = transform