diff --git a/bin/train.py b/bin/train.py index 1da17e6..61e9d1e 100644 --- a/bin/train.py +++ b/bin/train.py @@ -30,6 +30,7 @@ flags.DEFINE_boolean('stop', True, 'Stop aws instance after finished running.') flags.DEFINE_float('min_delta', 0.005, 'Early stopping minimum change value.') flags.DEFINE_integer('patience', 20, 'Early stopping epochs patience to wait before stopping.') +flags.DEFINE_boolean('augment', False, 'Whether open augmented data or initial data.') def main(_): config = BayesianConfig(FLAGS.encoder, FLAGS.dataset, FLAGS.batch_size, FLAGS.epochs, FLAGS.monte_carlo_simulations) @@ -37,7 +38,7 @@ def main(_): min_image_size = encoder_min_input_size(FLAGS.encoder) - ((x_train, y_train), (x_test, y_test)) = test_train_batch_data(FLAGS.dataset, FLAGS.encoder, FLAGS.debug, augment_data=True) + ((x_train, y_train), (x_test, y_test)) = test_train_batch_data(FLAGS.dataset, FLAGS.encoder, FLAGS.debug, FLAGS.augment) min_image_size = list(min_image_size) min_image_size.append(3)