Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 87bfac5

Browse files
author
Ryan Sepassi
committed
Add EarlyStoppingHook, PlateauOpHook, and MetricsBasedHook base class
PiperOrigin-RevId: 179860572
1 parent 758991d commit 87bfac5

File tree

8 files changed

+530
-22
lines changed

8 files changed

+530
-22
lines changed

tensor2tensor/bin/t2t-trainer

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,6 @@ def create_hparams():
7777

7878

7979
def create_experiment_fn():
80-
use_validation_monitor = (FLAGS.schedule in
81-
["train_and_evaluate", "continuous_train_and_eval"]
82-
and FLAGS.local_eval_frequency)
8380
return tpu_trainer_lib.create_experiment_fn(
8481
model_name=FLAGS.model,
8582
problem_name=get_problem_name(),
@@ -92,9 +89,9 @@ def create_experiment_fn():
9289
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
9390
use_tfdbg=FLAGS.tfdbg,
9491
use_dbgprofile=FLAGS.dbgprofile,
95-
use_validation_monitor=use_validation_monitor,
9692
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
9793
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
94+
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
9895
eval_early_stopping_metric_minimize=FLAGS.
9996
eval_early_stopping_metric_minimize,
10097
use_tpu=FLAGS.use_tpu)

tensor2tensor/bin/t2t_trainer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,6 @@ def create_hparams():
7676

7777

7878
def create_experiment_fn():
79-
use_validation_monitor = (FLAGS.schedule in
80-
["train_and_evaluate", "continuous_train_and_eval"]
81-
and FLAGS.local_eval_frequency)
8279
return tpu_trainer_lib.create_experiment_fn(
8380
model_name=FLAGS.model,
8481
problem_name=get_problem_name(),
@@ -91,9 +88,9 @@ def create_experiment_fn():
9188
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
9289
use_tfdbg=FLAGS.tfdbg,
9390
use_dbgprofile=FLAGS.dbgprofile,
94-
use_validation_monitor=use_validation_monitor,
9591
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
9692
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
93+
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
9794
eval_early_stopping_metric_minimize=FLAGS.
9895
eval_early_stopping_metric_minimize,
9996
use_tpu=FLAGS.use_tpu)

tensor2tensor/tpu/tpu_trainer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,6 @@ def create_hparams():
7676

7777

7878
def create_experiment_fn():
79-
use_validation_monitor = (FLAGS.schedule in
80-
["train_and_evaluate", "continuous_train_and_eval"]
81-
and FLAGS.local_eval_frequency)
8279
return tpu_trainer_lib.create_experiment_fn(
8380
model_name=FLAGS.model,
8481
problem_name=get_problem_name(),
@@ -91,9 +88,9 @@ def create_experiment_fn():
9188
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
9289
use_tfdbg=FLAGS.tfdbg,
9390
use_dbgprofile=FLAGS.dbgprofile,
94-
use_validation_monitor=use_validation_monitor,
9591
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
9692
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
93+
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
9794
eval_early_stopping_metric_minimize=FLAGS.
9895
eval_early_stopping_metric_minimize,
9996
use_tpu=FLAGS.use_tpu)

tensor2tensor/tpu/tpu_trainer_lib.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import os
23+
2224
# Dependency imports
2325

2426
from tensor2tensor.utils import devices
2527
from tensor2tensor.utils import expert_utils
28+
from tensor2tensor.utils import metrics_hook
2629
from tensor2tensor.utils import registry
2730
from tensor2tensor.utils import t2t_model
2831

@@ -186,7 +189,8 @@ def create_estimator(model_name,
186189

187190

188191
def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None,
189-
use_validation_monitor=False, validation_monitor_kwargs=None):
192+
use_validation_monitor=False, validation_monitor_kwargs=None,
193+
use_early_stopping=False, early_stopping_kwargs=None):
190194
"""Create train and eval hooks for Experiment."""
191195
train_monitors = []
192196
eval_hooks = []
@@ -208,6 +212,12 @@ def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None,
208212
tf.contrib.learn.monitors.ValidationMonitor(
209213
hooks=eval_hooks, **validation_monitor_kwargs))
210214

215+
if use_early_stopping:
216+
hook = metrics_hook.EarlyStoppingHook(**early_stopping_kwargs)
217+
# Adding to both training and eval so that eval aborts as well
218+
train_monitors.append(hook)
219+
eval_hooks.append(hook)
220+
211221
return train_monitors, eval_hooks
212222

213223

@@ -224,9 +234,9 @@ def create_experiment(run_config,
224234
decode_hparams=None,
225235
use_tfdbg=False,
226236
use_dbgprofile=False,
227-
use_validation_monitor=False,
228237
eval_early_stopping_steps=None,
229238
eval_early_stopping_metric=None,
239+
eval_early_stopping_metric_delta=None,
230240
eval_early_stopping_metric_minimize=True,
231241
use_tpu=False):
232242
"""Create Experiment."""
@@ -264,12 +274,29 @@ def create_experiment(run_config,
264274
early_stopping_rounds=eval_early_stopping_steps,
265275
early_stopping_metric=eval_early_stopping_metric,
266276
early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
277+
early_stopping_kwargs = dict(
278+
events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
279+
tag=eval_early_stopping_metric,
280+
num_plateau_steps=eval_early_stopping_steps,
281+
plateau_decrease=eval_early_stopping_metric_minimize,
282+
plateau_delta=eval_early_stopping_metric_delta,
283+
every_n_steps=min_eval_frequency)
284+
285+
# In-process eval (and possible early stopping)
286+
local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
287+
use_validation_monitor = (
288+
schedule in local_schedules and min_eval_frequency)
289+
# Distributed early stopping
290+
use_early_stopping = (
291+
schedule not in local_schedules and eval_early_stopping_steps)
267292
train_monitors, eval_hooks = create_hooks(
268293
use_tfdbg=use_tfdbg,
269294
use_dbgprofile=use_dbgprofile,
270295
dbgprofile_kwargs=dbgprofile_kwargs,
271296
use_validation_monitor=use_validation_monitor,
272-
validation_monitor_kwargs=validation_monitor_kwargs)
297+
use_early_stopping=use_early_stopping,
298+
validation_monitor_kwargs=validation_monitor_kwargs,
299+
early_stopping_kwargs=early_stopping_kwargs)
273300
hooks_kwargs = {"train_monitors": train_monitors, "eval_hooks": eval_hooks}
274301

275302
# Experiment

tensor2tensor/tpu/tpu_trainer_lib_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def testExperiment(self):
6868
eval_steps=1,
6969
min_eval_frequency=1,
7070
use_tpu=False)
71-
run_config = tpu_trainer_lib.create_run_config(num_gpus=0, use_tpu=False)
71+
run_config = tpu_trainer_lib.create_run_config(
72+
model_dir=self.data_dir, num_gpus=0, use_tpu=False)
7273
hparams = registry.hparams("transformer_tiny_tpu")()
7374
exp = exp_fn(run_config, hparams)
7475
exp.test()

tensor2tensor/utils/flags.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,14 @@
5555
flags.DEFINE_integer("train_steps", 250000,
5656
"The number of steps to run training for.")
5757
flags.DEFINE_string("eval_early_stopping_metric", "loss",
58-
"If --schedule=train_and_evaluate and "
59-
"--eval_early_stopping_steps is not None, then stop when "
60-
"--eval_early_stopping_metric has not decreased for "
58+
"If --eval_early_stopping_steps is not None, then stop "
59+
"when --eval_early_stopping_metric has not decreased for "
6160
"--eval_early_stopping_steps")
61+
flags.DEFINE_float("eval_early_stopping_metric_delta", 0.1,
62+
"Delta determining whether metric has plateaued.")
6263
flags.DEFINE_integer("eval_early_stopping_steps", None,
63-
"If --schedule=train_and_evaluate and "
64-
"--eval_early_stopping_steps is not None, then stop when "
65-
"--eval_early_stopping_metric has not decreased for "
64+
"If --eval_early_stopping_steps is not None, then stop "
65+
"when --eval_early_stopping_metric has not decreased for "
6666
"--eval_early_stopping_steps")
6767
flags.DEFINE_bool("eval_early_stopping_metric_minimize", True,
6868
"Whether to check for the early stopping metric going down "

0 commit comments

Comments
 (0)