Skip to content

Commit 9a4dbec

Browse files
Include model.eval()
1 parent 661c38e commit 9a4dbec

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

weight-initialization/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _get_loss_acc(model, train_loader, valid_loader):
4545
# after training for 2 epochs, check validation accuracy
4646
correct = 0
4747
total = 0
48+
model.eval()
4849
for data, target in valid_loader:
4950
# forward pass: compute predicted outputs by passing inputs to the model
5051
output = model(data)
@@ -106,4 +107,4 @@ def hist_dist(title, distribution_tensor, hist_range=(-4, 4)):
106107
"""
107108
plt.title(title)
108109
plt.hist(distribution_tensor, np.linspace(*hist_range, num=len(distribution_tensor)//2))
109-
plt.show()
110+
plt.show()

0 commit comments

Comments
 (0)