
Commit b10286e

Author: Ryan Sepassi

Pad eval batch to enable multi-device eval; skip T2TModel.top if T2TModel.body returns training loss

PiperOrigin-RevId: 179882031

1 parent: 45a4b88

6 files changed: +55, -6 lines


setup.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.4.0',
+    version='1.4.1',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',

tensor2tensor/bin/t2t-trainer

Lines changed: 5 additions & 1 deletion

@@ -61,7 +61,11 @@ try:
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass
 
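Note on the eval_steps change (this hunk is repeated verbatim in t2t_trainer.py and tpu_trainer.py below): the much larger default, 10000 instead of 200, leans on standard tf.estimator semantics, where Estimator.evaluate(steps=N) also stops as soon as the eval input_fn raises an end-of-input exception. A minimal sketch of that behavior follows; my_model_fn and eval_input_fn are hypothetical placeholders, not part of this commit.

import tensorflow as tf

# Hypothetical pieces, for illustration only. eval_input_fn is assumed
# to make a single pass over the eval set (no .repeat()), so it raises
# tf.errors.OutOfRangeError once the data is exhausted.
estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir="/tmp/t2t")

# steps=10000 is an upper bound, not a requirement: evaluation ends
# after 10000 batches or one full pass over the eval data, whichever
# comes first.
metrics = estimator.evaluate(input_fn=eval_input_fn, steps=10000)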

tensor2tensor/bin/t2t_trainer.py

Lines changed: 5 additions & 1 deletion

@@ -60,7 +60,11 @@
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass
 

tensor2tensor/data_generators/problem.py

Lines changed: 33 additions & 0 deletions

@@ -576,6 +576,19 @@ def define_shapes(example):
       batching_scheme["boundaries"],
       batching_scheme["batch_sizes"])
 
+  if not is_training:
+    def _pad_batch(features):
+      if not config or config.data_parallelism.n <= 1:
+        return features
+      tf.logging.warn(
+          "Padding the batch to ensure that remainder eval batches have "
+          "a batch size divisible by the number of data shards. This may "
+          "lead to incorrect metrics for non-zero-padded features, e.g. "
+          "images. Use a single datashard (i.e. 1 GPU) in that case.")
+      return pad_batch(features, config.data_parallelism.n)
+
+    dataset = dataset.map(_pad_batch, num_parallel_calls=num_threads)
+
   dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)
   dataset = dataset.prefetch(1)
   features = dataset.make_one_shot_iterator().get_next()

@@ -930,3 +943,23 @@ def standardize_shapes(features, batch_size=None):
     t.get_shape().assert_is_fully_defined()
 
   return features
+
+
+def pad_batch(features, batch_multiple):
+  """Pad batch dim of features to nearest multiple of batch_multiple."""
+  feature = features.items()[0][1]
+  batch_size = tf.shape(feature)[0]
+  mod = batch_size % batch_multiple
+  has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32)
+  batch_padding = batch_multiple * has_mod - mod
+
+  padded_features = {}
+  for k, feature in features.items():
+    rank = len(feature.shape)
+    paddings = []
+    for _ in range(rank):
+      paddings.append([0, 0])
+    paddings[0][1] = batch_padding
+    padded_feature = tf.pad(feature, paddings)
+    padded_features[k] = padded_feature
+  return padded_features
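Note on pad_batch: the arithmetic guarantees a no-op when the batch is already divisible. With batch_size = 7 and batch_multiple = 4: mod = 3, has_mod = 1, so batch_padding = 4 * 1 - 3 = 1 and the batch grows to 8. With batch_size = 8: mod = 0, has_mod = 0, so batch_padding = 0. A standalone sketch of the same logic follows (written Python-3-compatible; the commit's features.items()[0][1] assumes Python 2, where dict.items() returns a list):

import tensorflow as tf  # TF 1.x-era API, matching the commit

def pad_batch_sketch(features, batch_multiple):
  """Zero-pad the batch dim of each feature up to the next multiple
  of batch_multiple; a no-op when batch_size is already divisible."""
  any_feature = next(iter(features.values()))
  batch_size = tf.shape(any_feature)[0]
  mod = batch_size % batch_multiple
  # 1 if there is a remainder, 0 otherwise, so divisible batches get
  # batch_padding == 0 rather than a full extra batch_multiple.
  has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32)
  batch_padding = batch_multiple * has_mod - mod

  padded = {}
  for name, feature in features.items():
    paddings = [[0, 0] for _ in range(len(feature.shape))]
    paddings[0][1] = batch_padding  # pad only the trailing edge of dim 0
    padded[name] = tf.pad(feature, paddings)
  return padded

As the warning in the first hunk says, the padded rows are zeros, so metrics that do not ignore zero padding (e.g. on images) can be skewed; use a single data shard in that case.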

tensor2tensor/tpu/tpu_trainer.py

Lines changed: 5 additions & 1 deletion

@@ -60,7 +60,11 @@
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass
 

tensor2tensor/utils/t2t_model.py

Lines changed: 6 additions & 2 deletions

@@ -139,13 +139,15 @@ def model_fn_sharded(self, sharded_features):
       body_out = self.body_sharded(
           self._to_single_features_dict(transformed_features))
       body_out, losses = self._normalize_body_output(body_out)
-      sharded_logits = dp(self.top, body_out, datashard_to_features)
       if "training" not in losses:
+        sharded_logits = dp(self.top, body_out, datashard_to_features)
         sharded_losses = dp(self.loss, sharded_logits, datashard_to_features)
         training_loss_dict = average_sharded_losses([{
             "training": loss
         } for loss in sharded_losses])
         losses.update(training_loss_dict)
+      else:
+        sharded_logits = body_out
     else:
       sharded_logits, sharded_losses = dp(self.model_fn, datashard_to_features)
       losses = average_sharded_losses(sharded_losses)

@@ -172,9 +174,11 @@ def model_fn(self, features):
     body_out = self.body(transformed_features)
     output, losses = self._normalize_body_output(body_out)
 
-    logits = self.top(output, features)
     if "training" not in losses:
+      logits = self.top(output, features)
       losses["training"] = self.loss(logits, features)
+    else:
+      logits = output
     return logits, losses
 
   def bottom(self, features):
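Both hunks apply the same rule: when body (or body_sharded) already returns a losses dict containing a "training" entry, its output is passed through as the logits and top/loss are skipped; otherwise the usual top-then-loss path runs. A framework-free sketch of that dispatch follows (function names here are illustrative, not the T2T API):

def run_model(body, top, loss_fn, features):
  """Skip top/loss_fn when body supplies its own training loss."""
  output, losses = body(features)  # body returns (output, dict of losses)
  if "training" not in losses:
    # Usual path: project the body output to logits, then compute loss.
    logits = top(output, features)
    losses["training"] = loss_fn(logits, features)
  else:
    # body computed its own training loss; its output doubles as logits.
    logits = output
  return logits, losses

This lets a model whose body computes its own training loss bypass the output modality's top and loss entirely.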
