Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9394d0e

Browse files
T2T TeamRyan Sepassi
authored andcommitted
Use ModeKeys enum consistently in trainer_utils instead of string literals.
PiperOrigin-RevId: 164008619
1 parent 4390618 commit 9394d0e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tensor2tensor/utils/trainer_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def create_experiment(output_dir, data_dir, model_name, train_steps,
181181
eval_hooks.append(hook)
182182
return tf.contrib.learn.Experiment(
183183
estimator=estimator,
184-
train_input_fn=input_fns["train"],
185-
eval_input_fn=input_fns["eval"],
184+
train_input_fn=input_fns[tf.contrib.learn.ModeKeys.TRAIN],
185+
eval_input_fn=input_fns[tf.contrib.learn.ModeKeys.EVAL],
186186
eval_metrics=eval_metrics,
187187
train_steps=train_steps,
188188
eval_steps=eval_steps,
@@ -220,7 +220,9 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name):
220220
keep_checkpoint_max=FLAGS.keep_checkpoint_max))
221221
# Store the hparams in the estimator as well
222222
estimator.hparams = hparams
223-
return estimator, {"train": train_input_fn, "eval": eval_input_fn}
223+
return estimator, {
224+
tf.contrib.learn.ModeKeys.TRAIN: train_input_fn,
225+
tf.contrib.learn.ModeKeys.EVAL: eval_input_fn}
224226

225227

226228
def log_registry():

0 commit comments

Comments
 (0)