File tree Expand file tree Collapse file tree 3 files changed +7
-7
lines changed
convolutional-neural-networks/mnist-mlp Expand file tree Collapse file tree 3 files changed +7
-7
lines changed Original file line number Diff line number Diff line change 260260 " \n " ,
261261 " # print training statistics \n " ,
262262 " # calculate average loss over an epoch\n " ,
263- " train_loss = train_loss/len(train_loader.dataset )\n " ,
263+ " train_loss = train_loss/len(train_loader.sampler )\n " ,
264264 " \n " ,
265265 " print('Epoch: {} \\ tTraining Loss: {:.6f}'.format(\n " ,
266266 " epoch+1, \n " ,
315315 " class_total[label] += 1\n " ,
316316 " \n " ,
317317 " # calculate and print avg test loss\n " ,
318- " test_loss = test_loss/len(test_loader.dataset )\n " ,
318+ " test_loss = test_loss/len(test_loader.sampler )\n " ,
319319 " print('Test Loss: {:.6f}\\ n'.format(test_loss))\n " ,
320320 " \n " ,
321321 " for i in range(10):\n " ,
Original file line number Diff line number Diff line change 360360 " \n " ,
361361 " # print training statistics \n " ,
362362 " # calculate average loss over an epoch\n " ,
363- " train_loss = train_loss/len(train_loader.dataset )\n " ,
363+ " train_loss = train_loss/len(train_loader.sampler )\n " ,
364364 " \n " ,
365365 " print('Epoch: {} \\ tTraining Loss: {:.6f}'.format(\n " ,
366366 " epoch+1, \n " ,
430430 " class_total[label] += 1\n " ,
431431 " \n " ,
432432 " # calculate and print avg test loss\n " ,
433- " test_loss = test_loss/len(test_loader.dataset )\n " ,
433+ " test_loss = test_loss/len(test_loader.sampler )\n " ,
434434 " print('Test Loss: {:.6f}\\ n'.format(test_loss))\n " ,
435435 " \n " ,
436436 " for i in range(10):\n " ,
Original file line number Diff line number Diff line change 417417 " \n " ,
418418 " # print training/validation statistics \n " ,
419419 " # calculate average loss over an epoch\n " ,
420- " train_loss = train_loss/len(train_loader.dataset )\n " ,
421- " valid_loss = valid_loss/len(valid_loader.dataset )\n " ,
420+ " train_loss = train_loss/len(train_loader.sampler )\n " ,
421+ " valid_loss = valid_loss/len(valid_loader.sampler )\n " ,
422422 " \n " ,
423423 " print('Epoch: {} \\ tTraining Loss: {:.6f} \\ tValidation Loss: {:.6f}'.format(\n " ,
424424 " epoch+1, \n " ,
515515 " class_total[label] += 1\n " ,
516516 " \n " ,
517517 " # calculate and print avg test loss\n " ,
518- " test_loss = test_loss/len(test_loader.dataset )\n " ,
518+ " test_loss = test_loss/len(test_loader.sampler )\n " ,
519519 " print('Test Loss: {:.6f}\\ n'.format(test_loss))\n " ,
520520 " \n " ,
521521 " for i in range(10):\n " ,
You can’t perform that action at this time.
0 commit comments