Skip to content

Commit 6ba1dff

Browse files
committed
fix normalization of loss calculations
1 parent 752f70b commit 6ba1dff

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

convolutional-neural-networks/mnist-mlp/mnist_mlp_exercise.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@
260260
" \n",
261261
" # print training statistics \n",
262262
" # calculate average loss over an epoch\n",
263-
" train_loss = train_loss/len(train_loader.dataset)\n",
263+
" train_loss = train_loss/len(train_loader.sampler)\n",
264264
"\n",
265265
" print('Epoch: {} \\tTraining Loss: {:.6f}'.format(\n",
266266
" epoch+1, \n",
@@ -315,7 +315,7 @@
315315
" class_total[label] += 1\n",
316316
"\n",
317317
"# calculate and print avg test loss\n",
318-
"test_loss = test_loss/len(test_loader.dataset)\n",
318+
"test_loss = test_loss/len(test_loader.sampler)\n",
319319
"print('Test Loss: {:.6f}\\n'.format(test_loss))\n",
320320
"\n",
321321
"for i in range(10):\n",

convolutional-neural-networks/mnist-mlp/mnist_mlp_solution.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@
360360
" \n",
361361
" # print training statistics \n",
362362
" # calculate average loss over an epoch\n",
363-
" train_loss = train_loss/len(train_loader.dataset)\n",
363+
" train_loss = train_loss/len(train_loader.sampler)\n",
364364
"\n",
365365
" print('Epoch: {} \\tTraining Loss: {:.6f}'.format(\n",
366366
" epoch+1, \n",
@@ -430,7 +430,7 @@
430430
" class_total[label] += 1\n",
431431
"\n",
432432
"# calculate and print avg test loss\n",
433-
"test_loss = test_loss/len(test_loader.dataset)\n",
433+
"test_loss = test_loss/len(test_loader.sampler)\n",
434434
"print('Test Loss: {:.6f}\\n'.format(test_loss))\n",
435435
"\n",
436436
"for i in range(10):\n",

convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,8 @@
417417
" \n",
418418
" # print training/validation statistics \n",
419419
" # calculate average loss over an epoch\n",
420-
" train_loss = train_loss/len(train_loader.dataset)\n",
421-
" valid_loss = valid_loss/len(valid_loader.dataset)\n",
420+
" train_loss = train_loss/len(train_loader.sampler)\n",
421+
" valid_loss = valid_loss/len(valid_loader.sampler)\n",
422422
" \n",
423423
" print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n",
424424
" epoch+1, \n",
@@ -515,7 +515,7 @@
515515
" class_total[label] += 1\n",
516516
"\n",
517517
"# calculate and print avg test loss\n",
518-
"test_loss = test_loss/len(test_loader.dataset)\n",
518+
"test_loss = test_loss/len(test_loader.sampler)\n",
519519
"print('Test Loss: {:.6f}\\n'.format(test_loss))\n",
520520
"\n",
521521
"for i in range(10):\n",

0 commit comments

Comments
 (0)