Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 4f67b7b

Browse files
Lukasz Kaiser authored and Ryan Sepassi committed
Move algorithmic and WMT problems to Problem class, correct summaries.
PiperOrigin-RevId: 162424062
1 parent ecacf77 commit 4f67b7b

File tree

13 files changed

+600
-332
lines changed

13 files changed

+600
-332
lines changed

tensor2tensor/bin/t2t-datagen

Lines changed: 6 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@ import tempfile
3535

3636
import numpy as np
3737

38-
from tensor2tensor.data_generators import algorithmic
3938
from tensor2tensor.data_generators import algorithmic_math
4039
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
4140
from tensor2tensor.data_generators import audio
@@ -60,6 +59,8 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
6059
"Temporary storage directory.")
6160
flags.DEFINE_string("problem", "",
6261
"The name of the problem to generate data for.")
62+
flags.DEFINE_string("exclude_problems", "",
63+
"Comma-separates list of problems to exclude.")
6364
flags.DEFINE_integer("num_shards", 10, "How many shards to use.")
6465
flags.DEFINE_integer("max_cases", 0,
6566
"Maximum number of cases to generate (unbounded if 0).")
@@ -74,37 +75,6 @@ flags.DEFINE_string("t2t_usr_dir", "",
7475
# Mapping from problems that we can generate data for to their generators.
7576
# pylint: disable=g-long-lambda
7677
_SUPPORTED_PROBLEM_GENERATORS = {
77-
"algorithmic_shift_decimal40": (
78-
lambda: algorithmic.shift_generator(20, 10, 40, 100000),
79-
lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
80-
"algorithmic_reverse_binary40": (
81-
lambda: algorithmic.reverse_generator(2, 40, 100000),
82-
lambda: algorithmic.reverse_generator(2, 400, 10000)),
83-
"algorithmic_reverse_decimal40": (
84-
lambda: algorithmic.reverse_generator(10, 40, 100000),
85-
lambda: algorithmic.reverse_generator(10, 400, 10000)),
86-
"algorithmic_addition_binary40": (
87-
lambda: algorithmic.addition_generator(2, 40, 100000),
88-
lambda: algorithmic.addition_generator(2, 400, 10000)),
89-
"algorithmic_addition_decimal40": (
90-
lambda: algorithmic.addition_generator(10, 40, 100000),
91-
lambda: algorithmic.addition_generator(10, 400, 10000)),
92-
"algorithmic_multiplication_binary40": (
93-
lambda: algorithmic.multiplication_generator(2, 40, 100000),
94-
lambda: algorithmic.multiplication_generator(2, 400, 10000)),
95-
"algorithmic_multiplication_decimal40": (
96-
lambda: algorithmic.multiplication_generator(10, 40, 100000),
97-
lambda: algorithmic.multiplication_generator(10, 400, 10000)),
98-
"algorithmic_reverse_nlplike_decimal8K": (
99-
lambda: algorithmic.reverse_generator_nlplike(8000, 70, 100000,
100-
10, 1.300),
101-
lambda: algorithmic.reverse_generator_nlplike(8000, 70, 10000,
102-
10, 1.300)),
103-
"algorithmic_reverse_nlplike_decimal32K": (
104-
lambda: algorithmic.reverse_generator_nlplike(32000, 70, 100000,
105-
10, 1.050),
106-
lambda: algorithmic.reverse_generator_nlplike(32000, 70, 10000,
107-
10, 1.050)),
10878
"algorithmic_algebra_inverse": (
10979
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
11080
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
@@ -124,29 +94,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
12494
2**14, 2**9),
12595
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
12696
2**14, 2**9)),
127-
"wmt_enfr_characters": (
128-
lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
129-
lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
130-
"wmt_enfr_tokens_8k": (
131-
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
132-
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
133-
),
134-
"wmt_enfr_tokens_32k": (
135-
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
136-
lambda: wmt.enfr_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
137-
),
138-
"wmt_ende_characters": (
139-
lambda: wmt.ende_character_generator(FLAGS.tmp_dir, True),
140-
lambda: wmt.ende_character_generator(FLAGS.tmp_dir, False)),
14197
"wmt_ende_bpe32k": (
14298
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
14399
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
144-
"wmt_zhen_tokens_32k": (
145-
lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
146-
2**15, 2**15),
147-
lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, False,
148-
2**15, 2**15)
149-
),
150100
"lm1b_32k": (
151101
lambda: lm1b.generator(FLAGS.tmp_dir, True),
152102
lambda: lm1b.generator(FLAGS.tmp_dir, False)
@@ -285,6 +235,9 @@ def main(_):
285235
# Calculate the list of problems to generate.
286236
problems = sorted(
287237
list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
238+
for exclude in FLAGS.exclude_problems.split(","):
239+
if exclude:
240+
problems = [p for p in problems if exclude not in p]
288241
if FLAGS.problem and FLAGS.problem[-1] == "*":
289242
problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
290243
elif FLAGS.problem:
@@ -364,7 +317,7 @@ def generate_data_for_problem(problem):
364317

365318
def generate_data_for_registered_problem(problem_name):
366319
problem = registry.problem(problem_name)
367-
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
320+
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir, FLAGS.num_shards)
368321

369322

370323
if __name__ == "__main__":

0 commit comments

Comments
 (0)