Skip to content

Commit a60c71a

Browse files
Merge pull request #397 from salehsargolzaee/my-patch-1
Include model.train() and model.eval()
2 parents 661c38e + 8837012 commit a60c71a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

weight-initialization/helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def _get_loss_acc(model, train_loader, valid_loader):
2828
###################
2929
# train the model #
3030
###################
31+
model.train()
3132
for data, target in train_loader:
3233
# clear the gradients of all optimized variables
3334
optimizer.zero_grad()
@@ -45,6 +46,7 @@ def _get_loss_acc(model, train_loader, valid_loader):
4546
# after training for 2 epochs, check validation accuracy
4647
correct = 0
4748
total = 0
49+
model.eval()
4850
for data, target in valid_loader:
4951
# forward pass: compute predicted outputs by passing inputs to the model
5052
output = model(data)
@@ -106,4 +108,4 @@ def hist_dist(title, distribution_tensor, hist_range=(-4, 4)):
106108
"""
107109
plt.title(title)
108110
plt.hist(distribution_tensor, np.linspace(*hist_range, num=len(distribution_tensor)//2))
109-
plt.show()
111+
plt.show()

0 commit comments

Comments
 (0)