Skip to content

Commit 0f0a586

Browse files
committed
🚧 work in progress
1 parent 084c8d0 commit 0f0a586

File tree

3 files changed

+4153
-2452
lines changed

3 files changed

+4153
-2452
lines changed

assignment2/cifar10_exploration_dropout.ipynb

Lines changed: 4145 additions & 2450 deletions
Large diffs are not rendered by default.

assignment2/runner/estimate_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def estimate_model(model, X_train, y_train, X_val, y_val):
7676
'time': validation_time,
7777
}
7878

79-
# estimate how much time it would get to predict
79+
# estimate how much time it would get to prediction
8080
print('predict')
8181
validation_time = time.perf_counter()
8282
run_model(

assignment2/runner/model_runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import tensorflow as tf
77

88

9-
def run_model(sess, X, y, is_training, predict, loss_val, Xd, yd,
9+
def run_model(sess, X, y, is_training, predict, loss_val,
10+
Xd, yd,
1011
epochs=1, batch_size=64, print_every=100,
1112
training=None, plot_losses=False, learning_rate=None, learning_rate_value=10e-3, part_of_dataset=1.0,
1213
snapshot_name=None,
@@ -75,6 +76,11 @@ def run_model(sess, X, y, is_training, predict, loss_val, Xd, yd,
7576
# and (if given) perform a training step
7677
loss, corr, _ = sess.run(variables, feed_dict=feed_dict)
7778

79+
# TODO:
80+
# - we may want to calculate validation accuracy here
81+
# - maybe we need to store dynamic of accuracy (trainging) on each 10 (100) samples
82+
# or even each epoch
83+
7884
# aggregate performance stats
7985
losses.append(loss * actual_batch_size)
8086
correct += np.sum(corr)

0 commit comments

Comments
 (0)