Skip to content

Commit 8837012

Browse files
Update helpers.py
added `model.train()` and `model.eval()`
1 parent 9a4dbec commit 8837012

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

weight-initialization/helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def _get_loss_acc(model, train_loader, valid_loader):
2828
###################
2929
# train the model #
3030
###################
31+
model.train()
3132
for data, target in train_loader:
3233
# clear the gradients of all optimized variables
3334
optimizer.zero_grad()

0 commit comments

Comments
 (0)