11import torch
22import torch .nn as nn
33import torch .nn .functional as F
4-
4+ from utils import Lambda
55
66class EasyMnist (nn .Module ):
77
@@ -13,11 +13,49 @@ def __init__(self):
1313
1414 def forward (self , x_batch : torch .Tensor ):
1515 """Simple ReLU-based activations through all layers of the DNN.
16- Simple and effectively deep neural network. No frills.
16+ Simple and sufficiently deep neural network. No frills.
1717 """
1818 _input = x_batch .view (- 1 , 784 ) # shape for our linear1
1919 out1 = F .relu (self .linear1 (x_batch ))
2020 out2 = F .relu (self .linear2 (out1 ))
2121 out3 = F .relu (self .linear3 (out2 ))
2222
23- return out3
23+ return out3
24+
25+
26+ # for comparison with the above
27+ def EasyMnistSeq ():
28+ return nn .Sequential (
29+ Lambda (lambda x : x .reshape (- 1 , 784 )),
30+ nn .Linear (784 , 1000 ),
31+ nn .Relu (),
32+ nn .Linear (1000 , 300 ),
33+ nn .Relu (),
34+ nn .Linear (300 , 10 ),
35+ nn .Relu (),
36+ )
37+
38+
39+ class MnistConvNet (nn .Module ):
40+ def __init__ (self , interim_size = 16 ):
41+ """
42+ A simple and shallow deep CNN to show that morph will shrink this architecture,
43+ which will inherently be wasteful on the task of classifying MNIST digits with
44+ accuracy above 95%.
45+ By default produces a 1x16 -> 16x16 -> 16x10 convnet
46+ """
47+ super ().__init__ ()
48+ self .conv1 = nn .Conv2d (1 , interim_size , kernel_size = 3 , stride = 2 , padding = 1 )
49+ self .conv2 = nn .Conv2d (interim_size , interim_size , kernel_size = 3 , stride = 2 , padding = 1 )
50+ self .conv3 = nn .Conv2d (interim_size , 10 , kernel_size = 3 , stride = 2 , padding = 1 )
51+
52+ def forward (self , xb ):
53+ xb = xb .view (- 1 , 1 , 28 , 28 ) # any batch_size, 1 channel, 28x28 pixels
54+ xb = F .relu (self .conv1 (xb ))
55+ xb = F .relu (self .conv2 (xb ))
56+ xb = F .relu (self .conv3 (xb ))
57+ xb = F .avg_pool2d (xb , 4 )
58+
59+ # reshape the output to the second dimension of the pool size, and just fill the rest to whatever.
60+ return xb .view (- 1 , xb .size (1 ))
61+
0 commit comments