
Commit 20d7919

nshazeer authored and Ryan Sepassi committed
Add hparams.split_to_length option for chopping long fixed-length language modeling examples on data read. Make logic more explicit for whether to interpret batch_size as examples or as tokens. Some changes to ChoppedTextProblem.
PiperOrigin-RevId: 181631661
1 parent 17ce194 commit 20d7919

File tree

4 files changed: +189 −63 lines changed


tensor2tensor/bin/t2t_datagen.py

Lines changed: 12 additions & 4 deletions
@@ -67,8 +67,10 @@
     "If true, we only list the problems that will be generated.")
 flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
 flags.DEFINE_integer("task_id", -1, "For distributed data generation.")
+flags.DEFINE_integer("task_id_start", -1, "For distributed data generation.")
+flags.DEFINE_integer("task_id_end", -1, "For distributed data generation.")
 flags.DEFINE_integer(
-    "num_concurrent_processes", 34,
+    "num_concurrent_processes", 10,
     "Applies only to problems for which multiprocess_generate=True.")
 flags.DEFINE_string("t2t_usr_dir", "",
                     "Path to a Python module that will be imported. The "
@@ -203,7 +205,6 @@ def generate_data_in_process(arg):
   problem_name, data_dir, tmp_dir, task_id = arg
   problem = registry.problem(problem_name)
   problem.generate_data(data_dir, tmp_dir, task_id)
-  # return 0


 def generate_data_for_registered_problem(problem_name):
@@ -215,10 +216,17 @@ def generate_data_for_registered_problem(problem_name):
   data_dir = os.path.expanduser(FLAGS.data_dir)
   tmp_dir = os.path.expanduser(FLAGS.tmp_dir)
   if task_id is None and problem.multiprocess_generate:
-    problem.prepare_to_generate(data_dir, tmp_dir)
+    if FLAGS.task_id_start != -1:
+      assert FLAGS.task_id_end != -1
+      task_id_start = FLAGS.task_id_start
+      task_id_end = FLAGS.task_id_end
+    else:
+      task_id_start = 0
+      task_id_end = problem.num_generate_tasks
     pool = multiprocessing.Pool(processes=FLAGS.num_concurrent_processes)
+    problem.prepare_to_generate(data_dir, tmp_dir)
     args = [(problem_name, data_dir, tmp_dir, task_id)
-            for task_id in range(problem.num_generate_tasks)]
+            for task_id in range(task_id_start, task_id_end)]
     pool.map(generate_data_in_process, args)
   else:
     problem.generate_data(data_dir, tmp_dir, task_id)
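
For context on how the new flags are meant to be used: each invocation of t2t_datagen.py now handles the half-open range [task_id_start, task_id_end) of the problem's generate tasks, so a distributed run can split num_generate_tasks across machines. Below is a minimal illustrative sketch of that partitioning; the helper name and the worker count are hypothetical, not part of this commit.

```python
# Illustrative only: split a problem's generate tasks into contiguous
# [task_id_start, task_id_end) ranges, one per worker invocation of
# t2t_datagen.py. The helper name and worker count are hypothetical.

def partition_tasks(num_generate_tasks, num_workers):
  """Yield (task_id_start, task_id_end) ranges covering all tasks."""
  per_worker = -(-num_generate_tasks // num_workers)  # ceiling division
  for start in range(0, num_generate_tasks, per_worker):
    yield start, min(start + per_worker, num_generate_tasks)

if __name__ == "__main__":
  # e.g. 100 generate tasks split across 4 worker invocations
  for start, end in partition_tasks(100, 4):
    print("run datagen with --task_id_start=%d --task_id_end=%d" % (start, end))
```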

tensor2tensor/data_generators/problem.py

Lines changed: 142 additions & 54 deletions
@@ -102,6 +102,7 @@ def default_model_hparams():
       max_input_seq_length=0,
       max_target_seq_length=0,
       prepend_mode="none",
+      split_to_length=0,
       data_dir=None)


@@ -117,6 +118,12 @@ def preprocess_example_common(example, hparams, mode):
     else:
       example["targets"] = tf.concat(
           [example["inputs"], [0], example["targets"]], 0)
+  if hparams.split_to_length:
+    example["targets"] = tf.reshape(
+        example["targets"], [-1, hparams.split_to_length, 1, 1])
+    if len(example) != 1:
+      raise ValueError("split_to_length only works for LM problems")
+    return tf.data.Dataset.from_tensor_slices(example)
   return example


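A minimal standalone sketch of what the split_to_length reshape accomplishes, using NumPy instead of TensorFlow and a hypothetical chunk length: a single long targets vector becomes several fixed-length examples once the leading dimension is sliced back out, which is what from_tensor_slices does with the reshaped tensor. Note the reshape assumes the example length is a multiple of split_to_length.

```python
import numpy as np

split_to_length = 4          # hypothetical value of hparams.split_to_length
targets = np.arange(1, 13)   # one long language-modeling example, 12 tokens

# Mirror of tf.reshape(targets, [-1, split_to_length, 1, 1]):
chunks = targets.reshape(-1, split_to_length, 1, 1)

# tf.data.Dataset.from_tensor_slices then emits one example per leading index;
# here we simply iterate to show the resulting fixed-length examples.
for chunk in chunks:
  print(chunk.reshape(-1))   # -> [1 2 3 4], [5 6 7 8], [9 10 11 12]
```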
@@ -232,7 +239,29 @@ def max_length(self, model_hparams):
     Returns:
       an integer
     """
-    return model_hparams.max_length or model_hparams.batch_size
+    return (
+        model_hparams.split_to_length or
+        model_hparams.max_length or
+        model_hparams.batch_size)
+
+  @property
+  def batch_size_means_tokens(self):
+    """Do we specify hparams.batch_size in tokens per datashard per batch.
+
+    This is generally done for text problems.
+
+    If False, we assume that batch sizes are specified in examples per
+    datashard per batch.
+
+    TODO(noam): we should be more explicit and replace the hyperparameter
+    batch size with two hyperparameters:
+      hparams.examples_per_batch_per_datashard
+      hparams.tokens_per_batch_per_datashard
+
+    Returns:
+      a boolean
+    """
+    return False

   def dataset_filename(self):
     return self.name
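
To make the two interpretations concrete, a small arithmetic sketch in plain Python with illustrative numbers only: when batch_size_means_tokens is True, hparams.batch_size is a per-datashard token budget, so longer sequences mean fewer examples per batch; when it is False, every batch holds hparams.batch_size examples per datashard.

```python
# Illustrative numbers, not defaults from the library.
batch_size = 4096      # hparams.batch_size
sequence_length = 256  # length to which variable-length examples are bucketed/padded

# batch_size_means_tokens == True: batch_size is a token budget per datashard,
# so longer sequences mean fewer examples per batch.
examples_when_tokens = batch_size // sequence_length
print(examples_when_tokens)    # 16 examples of length 256 per datashard

# batch_size_means_tokens == False: batch_size counts examples per datashard,
# regardless of their (fixed) shape.
examples_when_examples = batch_size
print(examples_when_examples)  # 4096 examples per datashard
```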
@@ -620,23 +649,39 @@ def define_shapes(example):
     if is_training:
       dataset = dataset.repeat(None)

+    if self.batch_size_means_tokens:
+      batch_size_means_tokens = True
+    else:
+      if _are_shapes_fully_defined(dataset.output_shapes):
+        batch_size_means_tokens = False
+      else:
+        tf.logging.warning(
+            "Shapes are not fully defined. Assuming batch_size means tokens. "
+            "You should probably override batch_size_means_tokens() "
+            "in your problem subclass")
+        batch_size_means_tokens = True
+
     # Batching
-    if _are_shapes_fully_defined(dataset.output_shapes):
-      # Static shape features (e.g. images)
+    if not batch_size_means_tokens:
+      # Batch size means examples per datashard.
       if config and config.use_tpu:
+        # on TPU, we use params["batch_size"], which specifies the number of
+        # examples across all datashards
         tpu_batch_size = params["batch_size"]
         dataset = dataset.apply(
             tf.contrib.data.batch_and_drop_remainder(tpu_batch_size))
       else:
         num_shards = (config and config.data_parallelism.n) or 1
         dataset = dataset.batch(hparams.batch_size * num_shards)
     else:
-      # Variable length features
+      # batch_size means tokens per datashard
       if config and config.use_tpu:
         # On TPU, pad to max_length
         dataset = dataset.filter(tpu_valid_size)
         padded_shapes = _fill_shape_nones(
             dataset.output_shapes, none_filler=max_length)
+        # on TPU, we use params["batch_size"], which specifies the number of
+        # examples across all datashards
         dataset = dataset.apply(
             tf.contrib.data.padded_batch_and_drop_remainder(
                 params["batch_size"], padded_shapes))
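
The new control flow above can be summarized as a small standalone sketch, with the TensorFlow warning reduced to a print and a hypothetical function name: the problem's batch_size_means_tokens property wins, otherwise fully defined static shapes fall back to example-counting and undefined shapes default to token-counting with a warning.

```python
def resolve_batch_size_means_tokens(problem_says_tokens, shapes_fully_defined):
  """Decide how hparams.batch_size should be interpreted for batching.

  Mirrors the logic added above: the Problem subclass takes precedence,
  then the mode is inferred from whether output shapes are static.
  """
  if problem_says_tokens:
    return True
  if shapes_fully_defined:
    # Static-shape features (e.g. images): batch_size counts examples.
    return False
  print("Shapes are not fully defined. Assuming batch_size means tokens. "
        "You should probably override batch_size_means_tokens() "
        "in your problem subclass")
  return True
```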
@@ -648,6 +693,7 @@ def define_shapes(example):
           shard_multiplier=(config and config.data_parallelism.n) or 1,
           length_multiplier=self.get_hparams().batch_size_multiplier)
       if hparams.use_fixed_batch_size:
+        # Here batch_size really means examples per datashard.
         batching_scheme["batch_sizes"] = [hparams.batch_size]
         batching_scheme["boundaries"] = []
       dataset = data_reader.bucket_by_sequence_length(
@@ -818,6 +864,10 @@ def is_character_level(self):
   def targeted_vocab_size(self):
     raise NotImplementedError()  # Not needed if self.is_character_level.

+  @property
+  def batch_size_means_tokens(self):
+    return True
+
   def generator(self, data_dir, tmp_dir, is_training):
     """Generator for the training and evaluation data.

@@ -981,14 +1031,14 @@ class ChoppedTextProblem(Text2TextProblem):
   """Tokenize and chop text files into fixed-length language-modeling examples.

   The input data is a set of text files, as specified by
-  self.train_text_filenames() and self.dev_text_filenames().
+  self.train_text_filepaths() and self.dev_text_filepaths().

   The text is tokenized using a SubwordTextEncoder, and
   then split into examples, each of length self.sequence_length().
   """

-  def train_text_filenames(self, tmp_dir):
-    """Local filenames of text files containing training data.
+  def train_text_filepaths(self, tmp_dir):
+    """Local filepaths of text files containing training data.

     This function may want to download the files if they do not exist.

@@ -999,8 +1049,8 @@ def train_text_filenames(self, tmp_dir):
     """
     raise NotImplementedError()

-  def dev_text_filenames(self, tmp_dir):
-    """Local filenames of text files containing dev data.
+  def dev_text_filepaths(self, tmp_dir):
+    """Local filepaths of text files containing dev data.

     This function may want to download the files if they do not exist.

@@ -1016,15 +1066,15 @@ def sequence_length(self):
     """Length of each example (in tokens)."""
     raise NotImplementedError()

-  def max_length(self, unused_model_hparams):
-    return self.sequence_length
+  def max_length(self, model_hparams):
+    return model_hparams.split_to_length or self.sequence_length

   @property
   def is_character_level(self):
     return False

-  def text_filenames_for_task(self, tmp_dir, task_id):
-    """List of input filenames for a particular training or dev shard.
+  def text_filepaths_for_task(self, tmp_dir, task_id):
+    """List of input filepaths for a particular training or dev shard.

     Args:
       tmp_dir: a string
@@ -1035,49 +1085,69 @@ def text_filenames_for_task(self, tmp_dir, task_id):
     assert task_id >= 0
     assert task_id < self.num_train_shards + self.num_dev_shards
     if task_id < self.num_train_shards:
-      return [f for i, f in enumerate(self.train_text_filenames(tmp_dir))
+      return [f for i, f in enumerate(self.train_text_filepaths(tmp_dir))
               if i % self.num_train_shards == task_id]
     else:
-      return [f for i, f in enumerate(self.dev_text_filenames(tmp_dir))
+      return [f for i, f in enumerate(self.dev_text_filepaths(tmp_dir))
               if i % self.num_dev_shards == task_id - self.num_train_shards]

-  def filename_to_unicode_text(self, filename):
+  def filepath_to_unicode_strings(self, filepath):
     """Read text out of an input file.

-    The default just reads the text and converts to unicode.
+    The default just reads the text, converts to unicode and yields one
+    unicode string.

-    Subclasses can override this function in order to preprocess.
+    Subclasses can override this function in order to preprocess, and can
+    yield any number of strings.

     Args:
-      filename: a string
-    Returns:
-      a unicode string.
+      filepath: a string
+    Yields:
+      unicode strings.
     """
-    f = tf.gfile.Open(filename)
+    f = tf.gfile.Open(filepath)
     b = f.read()
-    return to_unicode_ignore_erros(b)
+    yield to_unicode_ignore_erros(b)
+
+  def file_generator(self,
+                     filepaths,
+                     max_chars_per_file=None,
+                     max_chars_total=None):
+    """Read complete text of input files and yield unicode strings.
+
+    By default, one unicode string is produced per file, but this is
+    not guaranteed, since subclasses can override
+    filepath_to_unicode_strings().

-  def file_generator(self, tmp_dir, task_id, max_files=None):
-    """Reads complete text of input files and returns as unicode.
+    max_chars_per_file and max_chars_total can also be specified, in which
+    case some strings may be truncated or dropped to limit the total
+    amount of output.

     Args:
-      tmp_dir: a string
-      task_id: an integer less than num_shards, or "train" for training shards
-      max_files: an optional integer
+      filepaths: a list of strings
+      max_chars_per_file: an optional integer
+      max_chars_total: an optional integer
     Yields:
       unicode strings
     """
-    count = 0
-    if task_id == "train":
-      fnames = self.train_text_filenames(tmp_dir)
-    else:
-      fnames = self.text_filenames_for_task(tmp_dir, task_id)
-    for fname in fnames:
+    chars_total = 0
+    for fname in filepaths:
+      chars_this_file = 0
       tf.logging.info("reading file %s" % fname)
-      yield self.filename_to_unicode_text(fname)
-      count += 1
-      if max_files and count == max_files:
-        return
+      for text in self.filepath_to_unicode_strings(fname):
+        if (max_chars_per_file and chars_this_file + len(text)
+            > max_chars_per_file):
+          text = text[:max_chars_per_file - chars_this_file]
+        if max_chars_total and chars_total + len(text) > max_chars_total:
+          text = text[:max_chars_total - chars_total]
+        chars_total += len(text)
+        chars_this_file += len(text)
+        if text:
+          yield text
+        if max_chars_per_file and chars_this_file >= max_chars_per_file:
+          break
+      if max_chars_total and chars_total >= max_chars_total:
+        break

   def example_generator(self, encoder, tmp_dir, task_id):
     """Generator for examples.
@@ -1089,17 +1159,29 @@ def example_generator(self, encoder, tmp_dir, task_id):
     Yields:
       feature dictionaries
     """
-    for ftext in self.file_generator(tmp_dir, task_id):
-      encoded = encoder.encode(ftext)
-      for start_pos in xrange(0, len(encoded), self.sequence_length):
-        targets = encoded[start_pos:start_pos + self.sequence_length]
-        if len(targets) < self.sequence_length:
-          if self.remainder_policy == "pad":
-            targets += [0] * (self.sequence_length - len(targets))
-          else:
-            assert self.remainder_policy == "drop"
-            continue
+    filepaths = self.text_filepaths_for_task(tmp_dir, task_id)
+    if task_id >= self.num_train_shards:
+      # this is dev data - limit the total length.
+      max_chars_per_file = self.max_dev_chars // (
+          self.num_dev_shards * len(filepaths))
+    else:
+      max_chars_per_file = None
+    tokens = []
+    for ftext in self.file_generator(
+        filepaths, max_chars_per_file=max_chars_per_file):
+      tokens.extend(encoder.encode(ftext))
+      pos = 0
+      while pos + self.sequence_length <= len(tokens):
+        yield {"inputs": [0], "targets": tokens[pos:pos + self.sequence_length]}
+        pos += self.sequence_length
+      if pos > 0:
+        tokens = tokens[pos:]
+    if self.remainder_policy == "pad":
+      if tokens:
+        targets = tokens + [0] * (self.sequence_length - len(tokens))
         yield {"inputs": [0], "targets": targets}
+    else:
+      assert self.remainder_policy == "drop"

   @property
   def remainder_policy(self):
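
For intuition, a standalone sketch in plain Python of the chopping behavior example_generator now has, with a hypothetical token stream and lengths: tokens from successive files are concatenated into one running stream, fixed-length windows are emitted as they fill, and only the final leftover is padded or dropped according to remainder_policy.

```python
def chop_examples(token_streams, sequence_length, remainder_policy="pad"):
  """Yield fixed-length target lists from a sequence of token lists.

  Standalone mirror of the new chopping loop: leftovers carry over across
  files, and only the final remainder is padded or dropped.
  """
  tokens = []
  for stream in token_streams:
    tokens.extend(stream)
    pos = 0
    while pos + sequence_length <= len(tokens):
      yield tokens[pos:pos + sequence_length]
      pos += sequence_length
    if pos > 0:
      tokens = tokens[pos:]
  if tokens and remainder_policy == "pad":
    yield tokens + [0] * (sequence_length - len(tokens))

# Two "files" of 5 and 4 tokens chopped into windows of 4:
print(list(chop_examples([[1, 2, 3, 4, 5], [6, 7, 8, 9]], 4)))
# -> [[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 0, 0]]
```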
@@ -1113,14 +1195,15 @@ def remainder_policy(self):
   def prepare_to_generate(self, data_dir, tmp_dir):
     """Make sure that the data is prepared and the vocab is generated."""
     self.get_or_generate_vocab(data_dir, tmp_dir)
-    self.train_text_filenames(tmp_dir)
-    self.dev_text_filenames(tmp_dir)
+    self.train_text_filepaths(tmp_dir)
+    self.dev_text_filepaths(tmp_dir)

   def get_or_generate_vocab(self, data_dir, tmp_dir):
     return generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
         self.file_generator(
-            tmp_dir, task_id="train", max_files=self.max_files_for_vocab))
+            self.train_text_filepaths(tmp_dir),
+            max_chars_total=self.max_chars_for_vocab))

   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     """Generates training/dev data.
@@ -1147,9 +1230,9 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
       generator_utils.shuffle_dataset([out_file])

   @property
-  def max_files_for_vocab(self):
-    """Number of input files to read when generating vocab."""
-    return 10
+  def max_chars_for_vocab(self):
+    """Number of characters of training data to use for generating vocab."""
+    return 10 ** 7

   @property
   def target_space_id(self):
@@ -1163,6 +1246,11 @@ def num_train_shards(self):
   def num_dev_shards(self):
     return 1

+  @property
+  def max_dev_chars(self):
+    """Limit dev set to at most this many characters (default 10M)."""
+    return 10 ** 7
+
   @property
   def multiprocess_generate(self):
     return True
