Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ca7b045

Browse files
author
Ryan Sepassi
committed
Problem base class and registry
PiperOrigin-RevId: 161892788
1 parent 1e18474 commit ca7b045

File tree

12 files changed

+546
-97
lines changed

12 files changed

+546
-97
lines changed

tensor2tensor/bin/t2t-datagen

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
2424
yields for each training example a dictionary mapping string feature names to
2525
lists of {string, int, float}. The generator will be run once for each mode.
2626
"""
27+
from __future__ import absolute_import
28+
from __future__ import division
29+
from __future__ import print_function
2730

2831
import random
2932
import tempfile
@@ -34,6 +37,7 @@ import numpy as np
3437

3538
from tensor2tensor.data_generators import algorithmic
3639
from tensor2tensor.data_generators import algorithmic_math
40+
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
3741
from tensor2tensor.data_generators import audio
3842
from tensor2tensor.data_generators import generator_utils
3943
from tensor2tensor.data_generators import image
@@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli
4347
from tensor2tensor.data_generators import wiki
4448
from tensor2tensor.data_generators import wmt
4549
from tensor2tensor.data_generators import wsj_parsing
50+
from tensor2tensor.utils import registry
4651

4752
import tensorflow as tf
4853

@@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
6267
# Mapping from problems that we can generate data for to their generators.
6368
# pylint: disable=g-long-lambda
6469
_SUPPORTED_PROBLEM_GENERATORS = {
65-
"algorithmic_identity_binary40": (
66-
lambda: algorithmic.identity_generator(2, 40, 100000),
67-
lambda: algorithmic.identity_generator(2, 400, 10000)),
68-
"algorithmic_identity_decimal40": (
69-
lambda: algorithmic.identity_generator(10, 40, 100000),
70-
lambda: algorithmic.identity_generator(10, 400, 10000)),
7170
"algorithmic_shift_decimal40": (
7271
lambda: algorithmic.shift_generator(20, 10, 40, 100000),
7372
lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
@@ -294,8 +293,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
294293

295294
# pylint: enable=g-long-lambda
296295

297-
UNSHUFFLED_SUFFIX = "-unshuffled"
298-
299296

300297
def set_random_seed():
301298
"""Set the random seed from flag everywhere."""
@@ -308,13 +305,15 @@ def main(_):
308305
tf.logging.set_verbosity(tf.logging.INFO)
309306

310307
# Calculate the list of problems to generate.
311-
problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
308+
problems = sorted(
309+
list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
312310
if FLAGS.problem and FLAGS.problem[-1] == "*":
313311
problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
314312
elif FLAGS.problem:
315313
problems = [p for p in problems if p == FLAGS.problem]
316314
else:
317315
problems = []
316+
318317
# Remove TIMIT if paths are not given.
319318
if not FLAGS.timit_paths:
320319
problems = [p for p in problems if "timit" not in p]
@@ -326,7 +325,8 @@ def main(_):
326325
problems = [p for p in problems if "ende_bpe" not in p]
327326

328327
if not problems:
329-
problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
328+
problems_str = "\n * ".join(
329+
sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
330330
error_msg = ("You must specify one of the supported problems to "
331331
"generate data for:\n * " + problems_str + "\n")
332332
error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
@@ -343,40 +343,50 @@ def main(_):
343343
for problem in problems:
344344
set_random_seed()
345345

346-
training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
347-
348-
if isinstance(dev_gen, int):
349-
# The dev set and test sets are generated as extra shards using the
350-
# training generator. The integer specifies the number of training
351-
# shards. FLAGS.num_shards is ignored.
352-
num_training_shards = dev_gen
353-
tf.logging.info("Generating data for %s.", problem)
354-
all_output_files = generator_utils.combined_data_filenames(
355-
problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards)
356-
generator_utils.generate_files(
357-
training_gen(), all_output_files, FLAGS.max_cases)
346+
if problem in _SUPPORTED_PROBLEM_GENERATORS:
347+
generate_data_for_problem(problem)
358348
else:
359-
# usual case - train data and dev data are generated using separate
360-
# generators.
361-
tf.logging.info("Generating training data for %s.", problem)
362-
train_output_files = generator_utils.train_data_filenames(
363-
problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards)
364-
generator_utils.generate_files(
365-
training_gen(), train_output_files, FLAGS.max_cases)
366-
tf.logging.info("Generating development data for %s.", problem)
367-
dev_shards = 10 if "coco" in problem else 1
368-
dev_output_files = generator_utils.dev_data_filenames(
369-
problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
370-
generator_utils.generate_files(dev_gen(), dev_output_files)
371-
all_output_files = train_output_files + dev_output_files
349+
generate_data_for_registered_problem(problem)
350+
351+
352+
def generate_data_for_problem(problem):
  """Generate and shuffle data files for one table-driven problem.

  Args:
    problem: key into _SUPPORTED_PROBLEM_GENERATORS. The entry is unpacked as
      a pair: a callable producing the training generator, and either a
      callable producing the dev generator or an int.
  """
  train_gen_fn, dev_spec = _SUPPORTED_PROBLEM_GENERATORS[problem]
  # Files are first written under the unshuffled name and then handed to
  # generator_utils.shuffle_dataset at the end.
  unshuffled_name = problem + generator_utils.UNSHUFFLED_SUFFIX

  if not isinstance(dev_spec, int):
    # Usual case: training and dev data come from separate generators.
    tf.logging.info("Generating training data for %s.", problem)
    train_files = generator_utils.train_data_filenames(
        unshuffled_name, FLAGS.data_dir, FLAGS.num_shards)
    generator_utils.generate_files(train_gen_fn(), train_files,
                                   FLAGS.max_cases)

    tf.logging.info("Generating development data for %s.", problem)
    # Problems with "coco" in their name get 10 dev shards, all others 1.
    if "coco" in problem:
      dev_shards = 10
    else:
      dev_shards = 1
    dev_files = generator_utils.dev_data_filenames(
        unshuffled_name, FLAGS.data_dir, dev_shards)
    generator_utils.generate_files(dev_spec(), dev_files)
    files_to_shuffle = train_files + dev_files
  else:
    # An int in the dev slot means the dev and test sets are generated as
    # extra shards using the training generator; the int is the number of
    # training shards. FLAGS.num_shards is ignored on this path.
    tf.logging.info("Generating data for %s.", problem)
    files_to_shuffle = generator_utils.combined_data_filenames(
        unshuffled_name, FLAGS.data_dir, dev_spec)
    generator_utils.generate_files(train_gen_fn(), files_to_shuffle,
                                   FLAGS.max_cases)

  tf.logging.info("Shuffling data...")
  generator_utils.shuffle_dataset(files_to_shuffle)
385+
372386

373-
tf.logging.info("Shuffling data...")
374-
for fname in all_output_files:
375-
records = generator_utils.read_records(fname)
376-
random.shuffle(records)
377-
out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
378-
generator_utils.write_records(records, out_fname)
379-
tf.gfile.Remove(fname)
387+
def generate_data_for_registered_problem(problem_name):
  """Generate data for a problem looked up in the registry.

  Args:
    problem_name: name passed to registry.problem(); the returned object's
      generate_data() is called with FLAGS.data_dir.
  """
  registry.problem(problem_name).generate_data(FLAGS.data_dir)
380390

381391

382392
if __name__ == "__main__":

tensor2tensor/data_generators/algorithmic.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,50 @@
2323

2424
from six.moves import xrange # pylint: disable=redefined-builtin
2525

26+
from tensor2tensor.data_generators import generator_utils as utils
27+
from tensor2tensor.data_generators import problem
28+
from tensor2tensor.utils import registry
29+
30+
31+
@registry.register_problem
class AlgorithmicIdentityBinary40(problem.Problem):
  """Algorithmic identity (copy) task over a two-symbol alphabet."""

  @property
  def num_symbols(self):
    """Number of distinct symbols passed to identity_generator."""
    return 2

  def generate_data(self, data_dir):
    """Write the training and dev files for this problem under data_dir.

    Training uses identity_generator with max_length 40 and 100000 cases;
    dev uses max_length 400 and 10000 cases.

    Args:
      data_dir: forwarded to training_filepaths() and dev_filepaths().
    """
    train_examples = identity_generator(self.num_symbols, 40, 100000)
    # NOTE(review): the 100 / 1 arguments look like shard counts -- confirm
    # against Problem.training_filepaths / Problem.dev_filepaths.
    train_paths = self.training_filepaths(data_dir, 100)
    utils.generate_files(train_examples, train_paths)

    dev_examples = identity_generator(self.num_symbols, 400, 10000)
    dev_paths = self.dev_filepaths(data_dir, 1)
    utils.generate_files(dev_examples, dev_paths)

  def hparams(self, defaults, unused_model_hparams):
    """Fill in modality and space-id fields on `defaults` in place.

    Args:
      defaults: hparams object that is mutated; nothing is returned.
      unused_model_hparams: not read.
    """
    symbol_count = self.num_symbols
    # The vocabulary is the task symbols plus the ids the inputs encoder
    # reserves.
    vocab_size = symbol_count + self._encoders["inputs"].num_reserved_ids
    symbol_modality = (registry.Modalities.SYMBOL, vocab_size)
    defaults.input_modality = {"inputs": symbol_modality}
    defaults.target_modality = symbol_modality
    defaults.input_space_id = problem.SpaceID.DIGIT_0
    defaults.target_space_id = problem.SpaceID.DIGIT_1
54+
55+
56+
@registry.register_problem
class AlgorithmicIdentityDecimal40(AlgorithmicIdentityBinary40):
  """Decimal variant of the algorithmic identity task.

  Only `num_symbols` is overridden here; the rest is inherited from
  AlgorithmicIdentityBinary40.
  """

  @property
  def num_symbols(self):
    """Number of distinct symbols for the decimal variant."""
    return 10
63+
2664

2765
def identity_generator(nbr_symbols, max_length, nbr_cases):
2866
"""Generator for the identity (copy) task on sequences of symbols.
2967
3068
The length of the sequence is drawn uniformly at random from [1, max_length]
31-
and then symbols are drawn uniformly at random from [2, nbr_symbols] until
69+
and then symbols are drawn uniformly at random from [2, nbr_symbols + 2) until
3270
nbr_cases sequences have been produced.
3371
3472
Args:
@@ -66,8 +104,10 @@ def shift_generator(nbr_symbols, shift, max_length, nbr_cases):
66104
for _ in xrange(nbr_cases):
67105
l = np.random.randint(max_length) + 1
68106
inputs = [np.random.randint(nbr_symbols - shift) + 2 for _ in xrange(l)]
69-
yield {"inputs": inputs,
70-
"targets": [i + shift for i in inputs] + [1]} # [1] for EOS
107+
yield {
108+
"inputs": inputs,
109+
"targets": [i + shift for i in inputs] + [1]
110+
} # [1] for EOS
71111

72112

73113
def reverse_generator(nbr_symbols, max_length, nbr_cases):
@@ -89,8 +129,10 @@ def reverse_generator(nbr_symbols, max_length, nbr_cases):
89129
for _ in xrange(nbr_cases):
90130
l = np.random.randint(max_length) + 1
91131
inputs = [np.random.randint(nbr_symbols) + 2 for _ in xrange(l)]
92-
yield {"inputs": inputs,
93-
"targets": list(reversed(inputs)) + [1]} # [1] for EOS
132+
yield {
133+
"inputs": inputs,
134+
"targets": list(reversed(inputs)) + [1]
135+
} # [1] for EOS
94136

95137

96138
def zipf_distribution(nbr_symbols, alpha):
@@ -106,7 +148,7 @@ def zipf_distribution(nbr_symbols, alpha):
106148
distr_map: list of float, Zipf's distribution over nbr_symbols.
107149
108150
"""
109-
tmp = np.power(np.arange(1, nbr_symbols+1), -alpha)
151+
tmp = np.power(np.arange(1, nbr_symbols + 1), -alpha)
110152
zeta = np.r_[0.0, np.cumsum(tmp)]
111153
return [x / zeta[-1] for x in zeta]
112154

@@ -128,11 +170,14 @@ def zipf_random_sample(distr_map, sample_len):
128170
# we have made a sanity check to overcome this issue. On the other hand,
129171
# t+1 is enough to save us from generating PAD(0) and EOS(1), which are
130172
# reserved symbols.
131-
return [t+1 if t > 0 else t+2 for t in np.searchsorted(distr_map, u)]
173+
return [t + 1 if t > 0 else t + 2 for t in np.searchsorted(distr_map, u)]
132174

133175

134-
def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases,
135-
scale_std_dev=100, alpha=1.5):
176+
def reverse_generator_nlplike(nbr_symbols,
177+
max_length,
178+
nbr_cases,
179+
scale_std_dev=100,
180+
alpha=1.5):
136181
"""Generator for the reversing nlp-like task on sequences of symbols.
137182
138183
The length of the sequence is drawn from a Gaussian(Normal) distribution
@@ -157,10 +202,12 @@ def reverse_generator_nlplike(nbr_symbols, max_length, nbr_cases,
157202
std_dev = max_length / scale_std_dev
158203
distr_map = zipf_distribution(nbr_symbols, alpha)
159204
for _ in xrange(nbr_cases):
160-
l = int(abs(np.random.normal(loc=max_length/2, scale=std_dev)) + 1)
205+
l = int(abs(np.random.normal(loc=max_length / 2, scale=std_dev)) + 1)
161206
inputs = zipf_random_sample(distr_map, l)
162-
yield {"inputs": inputs,
163-
"targets": list(reversed(inputs)) + [1]} # [1] for EOS
207+
yield {
208+
"inputs": inputs,
209+
"targets": list(reversed(inputs)) + [1]
210+
} # [1] for EOS
164211

165212

166213
def lower_endian_to_number(l, base):
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2017 The Tensor2Tensor Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Imports for problem modules."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
# pylint: disable=unused-import
21+
from tensor2tensor.data_generators import algorithmic
22+
from tensor2tensor.data_generators import algorithmic_math
23+
from tensor2tensor.data_generators import audio
24+
from tensor2tensor.data_generators import image
25+
from tensor2tensor.data_generators import lm1b
26+
from tensor2tensor.data_generators import ptb
27+
from tensor2tensor.data_generators import snli
28+
from tensor2tensor.data_generators import wiki
29+
from tensor2tensor.data_generators import wmt
30+
from tensor2tensor.data_generators import wsj_parsing
31+
# pylint: enable=unused-import

0 commit comments

Comments
 (0)