Skip to content

Commit b82f182

Browse files
authored
Merge pull request #336 from abhinavnayak11/patch-1
Updated calculation of train_loss and valid_loss
2 parents 5dcb0a8 + ec87c23 commit b82f182

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

convolutional-neural-networks/mnist-mlp/mnist_mlp_solution_with_validation.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@
418418
" # perform a single optimization step (parameter update)\n",
419419
" optimizer.step()\n",
420420
" # update running training loss\n",
421-
" train_loss += loss.item()*data.size(0)\n",
421+
" train_loss += loss.item()\n",
422422
" \n",
423423
" ###################### \n",
424424
" # validate the model #\n",
@@ -430,12 +430,12 @@
430430
" # calculate the loss\n",
431431
" loss = criterion(output, target)\n",
432432
" # update running validation loss \n",
433-
" valid_loss += loss.item()*data.size(0)\n",
433+
" valid_loss += loss.item()\n",
434434
" \n",
435435
" # print training/validation statistics \n",
436436
" # calculate average loss over an epoch\n",
437-
" train_loss = train_loss/len(train_loader.dataset)\n",
438-
" valid_loss = valid_loss/len(valid_loader.dataset)\n",
437+
" train_loss = train_loss/len(train_loader)\n",
438+
" valid_loss = valid_loss/len(valid_loader)\n",
439439
" \n",
440440
" print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n",
441441
" epoch+1, \n",

0 commit comments

Comments
 (0)