|
418 | 418 | " # perform a single optimization step (parameter update)\n", |
419 | 419 | " optimizer.step()\n", |
420 | 420 | " # update running training loss\n", |
421 | | - " train_loss += loss.item()*data.size(0)\n", |
| 421 | + " train_loss += loss.item()\n", |
422 | 422 | " \n", |
423 | 423 | " ###################### \n", |
424 | 424 | " # validate the model #\n", |
|
430 | 430 | " # calculate the loss\n", |
431 | 431 | " loss = criterion(output, target)\n", |
432 | 432 | " # update running validation loss \n", |
433 | | - " valid_loss += loss.item()*data.size(0)\n", |
| 433 | + " valid_loss += loss.item()\n", |
434 | 434 | " \n", |
435 | 435 | " # print training/validation statistics \n", |
436 | 436 | " # calculate average loss over an epoch\n", |
437 | | - " train_loss = train_loss/len(train_loader.dataset)\n", |
438 | | - " valid_loss = valid_loss/len(valid_loader.dataset)\n", |
| 437 | + " train_loss = train_loss/len(train_loader)\n", |
| 438 | + " valid_loss = valid_loss/len(valid_loader)\n", |
439 | 439 | " \n", |
440 | 440 | " print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n", |
441 | 441 | " epoch+1, \n", |
|
0 commit comments