This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit e0fcdbc

T2T Team authored and Ryan Sepassi committed
Add support for multivalent hparams, fix for device variable scope.
PiperOrigin-RevId: 181776339
1 parent 6bc1efb commit e0fcdbc

6 files changed (+274, −97 lines)

tensor2tensor/data_generators/problem.py

Lines changed: 10 additions & 5 deletions

@@ -478,9 +478,14 @@ def dataset(self,
     def _load_records(filename):
       return tf.data.TFRecordDataset(filename, buffer_size=16 * 1000 * 1000)
 
-    dataset = dataset.apply(
-        tf.contrib.data.parallel_interleave(
-            _load_records, sloppy=is_training, cycle_length=8))
+    if hasattr(tf.contrib.data, "parallel_interleave"):
+      interleave = lambda ds, fn: ds.apply(  # pylint: disable=g-long-lambda
+          tf.contrib.data.parallel_interleave(
+              fn, sloppy=is_training, cycle_length=16))
+    else:
+      interleave = lambda ds, fn: ds.interleave(fn, cycle_length=16)
+
+    dataset = interleave(dataset, _load_records)
 
     if repeat:
       dataset = dataset.repeat()
@@ -508,7 +513,7 @@ def _preprocess(example):
     dataset = dataset.map(self.decode_example, num_parallel_calls=num_threads)
 
     if preprocess:
-      dataset = dataset.flat_map(_preprocess)
+      dataset = interleave(dataset, _preprocess)
 
     dataset = dataset.map(
         _maybe_reverse_and_copy, num_parallel_calls=num_threads)
@@ -716,7 +721,7 @@ def _pad_batch(features):
     dataset = dataset.map(_pad_batch, num_parallel_calls=num_threads)
 
     dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)
-    dataset = dataset.prefetch(1)
+    dataset = dataset.prefetch(2)
     features = dataset.make_one_shot_iterator().get_next()
     if not config or not config.use_tpu:
       _summarize_features(features, (config and config.data_parallelism.n) or 1)
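
The first hunk makes parallel_interleave optional: on TensorFlow builds whose tf.contrib.data lacks that op, the pipeline falls back to the sequential Dataset.interleave. It also raises cycle_length from 8 to 16 and reuses the same helper for preprocessing in place of flat_map. Below is a minimal, self-contained sketch of the same compatibility pattern, assuming a TF 1.x runtime; the helper name make_interleave_fn, the file pattern, and the cycle_length default are illustrative, not part of the commit.

# Sketch of the hasattr-based fallback used in the diff above (TF 1.x assumed).
# make_interleave_fn and the file pattern below are hypothetical illustrations.
import tensorflow as tf

def make_interleave_fn(is_training, cycle_length=16):
  """Returns interleave(ds, fn), parallel when tf.contrib.data supports it."""
  if hasattr(tf.contrib.data, "parallel_interleave"):
    # Parallel, optionally non-deterministic ("sloppy") interleave for
    # higher input-pipeline throughput during training.
    return lambda ds, fn: ds.apply(  # pylint: disable=g-long-lambda
        tf.contrib.data.parallel_interleave(
            fn, sloppy=is_training, cycle_length=cycle_length))
  # Older releases: fall back to the sequential Dataset.interleave.
  return lambda ds, fn: ds.interleave(fn, cycle_length=cycle_length)

# Usage: interleave TFRecord readers over a dataset of shard filenames.
filenames = tf.data.Dataset.list_files("/tmp/train-*.tfrecord")
interleave = make_interleave_fn(is_training=True)
dataset = interleave(
    filenames,
    lambda f: tf.data.TFRecordDataset(f, buffer_size=16 * 1000 * 1000))

The separate prefetch(1) → prefetch(2) change in the last hunk lets the pipeline keep one additional element buffered ahead of the consumer.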

tensor2tensor/layers/common_hparams.py

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@ def basic_params1():
       weight_decay=0.1,
       weight_noise=0.0,
       learning_rate_decay_scheme="none",
+      learning_rate_minimum=None,
+      learning_rate_decay_rate=1.0,
       learning_rate_warmup_steps=100,
       learning_rate_cosine_cycle_steps=250000,
       learning_rate=0.1,
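
This hunk only registers defaults on the base hparams set; the code that reads the new fields is not shown here. The defaults are conservative: under a multiplicative schedule a decay rate of 1.0 leaves the rate unchanged, and None means no minimum is enforced. Below is a hypothetical sketch of how such hparams could be consumed, built on standard TF 1.x ops; the function name and the decay_steps value are illustrative, not tensor2tensor's actual schedule.

# Hypothetical consumer of the two new hparams (not tensor2tensor's actual
# schedule code): exponential decay with an optional lower bound.
import tensorflow as tf

def decayed_learning_rate(hparams, global_step, decay_steps=1000):
  lr = tf.train.exponential_decay(
      hparams.learning_rate,               # base rate, e.g. 0.1
      global_step,
      decay_steps,                         # illustrative step interval
      hparams.learning_rate_decay_rate)    # 1.0 leaves the rate unchanged
  if hparams.learning_rate_minimum is not None:
    # Clamp so the schedule never drops below the configured minimum.
    lr = tf.maximum(lr, hparams.learning_rate_minimum)
  return lr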
