Bases: nn.Module
A small convnet for extracting features from MNIST images.
Source code in pytorch_adapt/models/mnist.py
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class MNISTFeatures(nn.Module):
    """
    A small convnet for extracting features
    from MNIST.

    The network is two (conv -> ReLU -> 2x2 max-pool) stages followed by a
    flatten and an identity head. By default it expects 3-channel images
    (presumably MNIST replicated across RGB channels; confirm against the
    dataset pipeline).

    For a 28x28 input, the spatial size shrinks 28 -> 24 -> 12 -> 8 -> 4.
    The output therefore has 48 * 4 * 4 = 768 features per sample.
    """

    def __init__(self, in_channels: int = 3):
        """
        Arguments:
            in_channels: Number of channels in the input images. Defaults to 3,
                which reproduces the original fixed architecture. Pass 1 to
                consume raw single-channel MNIST directly.
        """
        super().__init__()
        # Attribute names and registration order are unchanged, so
        # state_dict keys stay compatible with previously saved weights.
        self.conv1 = nn.Conv2d(in_channels, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 48, 5, 1)
        # No-op head kept as a named submodule. Presumably it exists so
        # callers can swap in a real projection later; confirm with users.
        self.fc = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Batched images of shape (N, in_channels, H, W). The batch
                dimension is required, because flattening starts at dim 1.
        Returns:
            Features of shape (N, 48 * H' * W'). H' and W' are the spatial
            dims after both conv/pool stages, e.g. 4x4 for a 28x28 input,
            giving 768 features.
        """
        # Stage 1: 5x5 conv (no padding) shrinks each side by 4; the pool halves it.
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        # Stage 2: same pattern, 32 -> 48 channels.
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        # Keep the batch dim, and collapse channels and spatial dims into one vector.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)
|