@@ -82,3 +82,67 @@ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/0fa674e43
8282
8383## Step 3 : Training the model on the dataset
8484
85+ We use SGD as the optimization algorithm, with learning rate (lr) = 0.003 and momentum = 0.9, values commonly suggested as a starting point. [Typical lr values range from 0.0001 up to 1, and it is up to us to find a suitable value by cross-validation.]
86+
87+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L33
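To make the roles of the learning rate and momentum concrete, here is a minimal pure-Python sketch of the update that SGD with momentum applies to a single parameter. It follows the velocity form used by PyTorch (`v = momentum * v + grad`, then `p = p - lr * v`, with no dampening) and applies it to a toy quadratic loss rather than to the network. The function names and the toy loss are illustrative assumptions, not part of the script.

```python
def sgd_momentum_step(param, grad, velocity, lr=0.003, momentum=0.9):
    """One SGD-with-momentum update for a single scalar parameter."""
    velocity = momentum * velocity + grad   # accumulate a running direction
    param = param - lr * velocity           # step against the accumulated gradient
    return param, velocity


def minimize_quadratic(start=5.0, steps=2000):
    """Minimize the toy loss f(x) = x**2 (gradient 2*x) with the step above."""
    x, v = start, 0.0
    for _ in range(steps):
        x, v = sgd_momentum_step(x, 2.0 * x, v)
    return x


if __name__ == "__main__":
    print(minimize_quadratic())  # a value very close to 0
```

With momentum = 0.9 the accumulated velocity makes each step behave roughly like plain SGD with a ten times larger learning rate, which is one reason a small lr such as 0.003 can still train quickly.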
88+
89+ The `time` module is used to measure the total training time (lines 34 and 48 of the script).
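A minimal sketch of this timing pattern is shown below. It assumes a placeholder `train_one_epoch` function standing in for the real training loop; that function is hypothetical and does no actual training.

```python
import time


def train_one_epoch():
    """Hypothetical stand-in for one pass over the training data."""
    return sum(i * i for i in range(10_000))


def timed_training(epochs=3):
    """Run the placeholder loop and return the elapsed wall-clock time in minutes."""
    start = time.time()                  # record the start time before training
    for _ in range(epochs):
        train_one_epoch()
    return (time.time() - start) / 60    # seconds -> minutes


if __name__ == "__main__":
    print(f"Training time: {timed_training():.4f} minutes")
```

`time.perf_counter()` is a monotonic alternative to `time.time()` if the system clock might be adjusted during a long run.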
90+
91+ A suitable number of epochs can be found by trial and error; in this code it is set to 18.
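One way to make the trial-and-error search less manual is to record a validation metric after every epoch and pick the epoch where it peaks. The sketch below assumes such a list of per-epoch validation accuracies is available. The helper `best_epoch` and the example numbers are made up for illustration and are not results from this model.

```python
def best_epoch(val_accuracies):
    """Return the 1-based epoch with the highest validation accuracy.

    Ties go to the earliest epoch, since extra epochs only add training time.
    """
    if not val_accuracies:
        raise ValueError("need at least one epoch of results")
    best_index = max(range(len(val_accuracies)), key=lambda i: val_accuracies[i])
    return best_index + 1


if __name__ == "__main__":
    # Made-up accuracies for five epochs, purely illustrative.
    history = [0.91, 0.95, 0.97, 0.97, 0.96]
    print(best_epoch(history))  # 3
```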
92+
93+ The overall training is done as follows:
94+
95+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L33-L49
96+
97+ ## Step 4 : Testing the Model
98+
99+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L51-L66
100+
101+ ## Step 5 : Saving the model
102+
103+ https://github.com/infinitecoder1729/mnist-dataset-classification/blob/a014ffaeead36b9a8d1458b51b6f70fc3d8873e3/MNIST%20Classification%20Model..py#L68
104+
105+ ## To view the result for a random picture in the dataset, the following code can be used :
106+
107+ It also plots a bar chart of the class probabilities returned by the model.
108+
109+ ```py
110+ import numpy as np
111+ def view_classify(img, ps):
112+     ps = ps.cpu().data.numpy().squeeze()  # probabilities as a flat NumPy array
113+     fig, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
114+     ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze())  # left: the digit image
115+     ax1.axis('off')
116+     ax2.barh(np.arange(10), ps)  # right: one horizontal bar per digit class
117+     ax2.set_aspect(0.1)
118+     ax2.set_yticks(np.arange(10))
119+     ax2.set_yticklabels(np.arange(10))
120+     ax2.set_title('Class Probability')
121+     ax2.set_xlim(0, 1.1)
122+     plt.tight_layout()
123+ img, label = train[np.random.randint(0, 10001)]
124+ image = img.view(1, 784)  # flatten the 28x28 image into a 784-element row
125+ with tch.no_grad():
126+     logps = model(image)
127+ ps = tch.exp(logps)  # exp turns log-probabilities into probabilities
128+ probab = list(ps.numpy()[0])
129+ print("Predicted Digit =", probab.index(max(probab)))
130+ view_classify(image.view(1, 28, 28), ps)
131+ ```
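The snippet above exponentiates the model output before reading off the prediction, which suggests the model returns log-probabilities (for example from a log-softmax layer; this is an assumption based on the variable name `logps`). A small pure-Python sketch of that conversion follows, using hypothetical scores for three classes instead of ten.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of raw scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def predict(log_probs):
    """Convert log-probabilities to probabilities and return (class, probs)."""
    probs = [math.exp(lp) for lp in log_probs]
    return probs.index(max(probs)), probs


if __name__ == "__main__":
    label, probs = predict(log_softmax([1.0, 3.0, 0.5]))  # hypothetical scores
    print(label, round(sum(probs), 6))  # 1 1.0
```

Because `exp` is monotonic, the predicted class is the same whether the maximum is taken over the log-probabilities or the probabilities; the conversion matters only for the probability bar chart.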
132+
133+ ### Examples :
134+
135+ ![image](https://user-images.githubusercontent.com/77016507/225422901-908e96de-629f-4d33-b7ba-819960a97d66.png)
136+
137+ ![image](https://user-images.githubusercontent.com/77016507/225423008-3f858a52-2331-48e1-b271-f6d6e25e2d91.png)
138+
139+ ![image](https://user-images.githubusercontent.com/77016507/225423232-d0249b38-e191-495d-b9fd-8c32eb20da57.png)
140+
141+ ### Model Accuracy : The accuracy of the model with this code is approximately 97.8% to 98.02%, with a training time of approximately 3.5 to 4 minutes.
142+
143+ ## Further Improvements :
144+
145+ 1. Adding graphical representations of useful data, such as Loss vs Epoch Number.
146+ 2. Testing different algorithms to strike a balance between training time and accuracy.
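As a possible starting point for the first item, the sketch below averages per-batch losses into one value per epoch and plots them with matplotlib. It assumes the training loop collects losses into a list of lists (one inner list per epoch). The names `mean_loss_per_epoch` and `plot_loss_curve` and the sample numbers are hypothetical and are not taken from the script.

```python
def mean_loss_per_epoch(batch_losses):
    """Average the batch losses of each epoch: [[...], [...]] -> [mean, mean]."""
    return [sum(epoch) / len(epoch) for epoch in batch_losses]


def plot_loss_curve(epoch_losses, path="loss_vs_epoch.png"):
    """Save a Loss vs Epoch Number line plot to `path`."""
    # Imported here so the averaging helper works without matplotlib installed.
    import matplotlib
    matplotlib.use("Agg")            # render to a file, no display needed
    import matplotlib.pyplot as plt

    epochs = range(1, len(epoch_losses) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, epoch_losses, marker="o")
    ax.set_xlabel("Epoch Number")
    ax.set_ylabel("Mean training loss")
    ax.set_title("Loss vs Epoch Number")
    fig.savefig(path)
    plt.close(fig)


if __name__ == "__main__":
    # Made-up batch losses for three epochs, for illustration only.
    losses = mean_loss_per_epoch([[0.9, 0.7], [0.5, 0.3], [0.2, 0.2]])
    plot_loss_curve(losses)
```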
147+
148+ ### Contributions, Suggestions, and inputs on graphical representation for better understanding are welcome.