@@ -35,7 +35,6 @@ import tempfile
3535
3636import numpy as np
3737
38- from tensor2tensor .data_generators import algorithmic
3938from tensor2tensor .data_generators import algorithmic_math
4039from tensor2tensor .data_generators import all_problems # pylint: disable=unused-import
4140from tensor2tensor .data_generators import audio
@@ -60,52 +59,22 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
6059 "Temporary storage directory." )
6160flags .DEFINE_string ("problem" , "" ,
6261 "The name of the problem to generate data for." )
62+ flags .DEFINE_string ("exclude_problems" , "" ,
63+ "Comma-separated list of problems to exclude." )
6364flags .DEFINE_integer ("num_shards" , 10 , "How many shards to use." )
6465flags .DEFINE_integer ("max_cases" , 0 ,
6566 "Maximum number of cases to generate (unbounded if 0)." )
6667flags .DEFINE_integer ("random_seed" , 429459 , "Random seed to use." )
67-
6868flags .DEFINE_string ("t2t_usr_dir" , "" ,
6969 "Path to a Python module that will be imported. The "
7070 "__init__.py file should include the necessary imports. "
7171 "The imported files should contain registrations, "
72- "e.g. @registry.register_model calls, that will then be "
73- "available to the t2t-datagen." )
72+ "e.g. @registry.register_problem calls, that will then be "
73+ "available to t2t-datagen." )
7474
7575# Mapping from problems that we can generate data for to their generators.
7676# pylint: disable=g-long-lambda
7777_SUPPORTED_PROBLEM_GENERATORS = {
78- "algorithmic_shift_decimal40" : (
79- lambda : algorithmic .shift_generator (20 , 10 , 40 , 100000 ),
80- lambda : algorithmic .shift_generator (20 , 10 , 80 , 10000 )),
81- "algorithmic_reverse_binary40" : (
82- lambda : algorithmic .reverse_generator (2 , 40 , 100000 ),
83- lambda : algorithmic .reverse_generator (2 , 400 , 10000 )),
84- "algorithmic_reverse_decimal40" : (
85- lambda : algorithmic .reverse_generator (10 , 40 , 100000 ),
86- lambda : algorithmic .reverse_generator (10 , 400 , 10000 )),
87- "algorithmic_addition_binary40" : (
88- lambda : algorithmic .addition_generator (2 , 40 , 100000 ),
89- lambda : algorithmic .addition_generator (2 , 400 , 10000 )),
90- "algorithmic_addition_decimal40" : (
91- lambda : algorithmic .addition_generator (10 , 40 , 100000 ),
92- lambda : algorithmic .addition_generator (10 , 400 , 10000 )),
93- "algorithmic_multiplication_binary40" : (
94- lambda : algorithmic .multiplication_generator (2 , 40 , 100000 ),
95- lambda : algorithmic .multiplication_generator (2 , 400 , 10000 )),
96- "algorithmic_multiplication_decimal40" : (
97- lambda : algorithmic .multiplication_generator (10 , 40 , 100000 ),
98- lambda : algorithmic .multiplication_generator (10 , 400 , 10000 )),
99- "algorithmic_reverse_nlplike_decimal8K" : (
100- lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 100000 ,
101- 10 , 1.300 ),
102- lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 10000 ,
103- 10 , 1.300 )),
104- "algorithmic_reverse_nlplike_decimal32K" : (
105- lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 100000 ,
106- 10 , 1.050 ),
107- lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 10000 ,
108- 10 , 1.050 )),
10978 "algorithmic_algebra_inverse" : (
11079 lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
11180 lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 )),
@@ -125,29 +94,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
12594 2 ** 14 , 2 ** 9 ),
12695 lambda : wsj_parsing .parsing_token_generator (FLAGS .tmp_dir , False ,
12796 2 ** 14 , 2 ** 9 )),
128- "wmt_enfr_characters" : (
129- lambda : wmt .enfr_character_generator (FLAGS .tmp_dir , True ),
130- lambda : wmt .enfr_character_generator (FLAGS .tmp_dir , False )),
131- "wmt_enfr_tokens_8k" : (
132- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 13 ),
133- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 13 )
134- ),
135- "wmt_enfr_tokens_32k" : (
136- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
137- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
138- ),
139- "wmt_ende_characters" : (
140- lambda : wmt .ende_character_generator (FLAGS .tmp_dir , True ),
141- lambda : wmt .ende_character_generator (FLAGS .tmp_dir , False )),
14297 "wmt_ende_bpe32k" : (
14398 lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , True ),
14499 lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , False )),
145- "wmt_zhen_tokens_32k" : (
146- lambda : wmt .zhen_wordpiece_token_generator (FLAGS .tmp_dir , True ,
147- 2 ** 15 , 2 ** 15 ),
148- lambda : wmt .zhen_wordpiece_token_generator (FLAGS .tmp_dir , False ,
149- 2 ** 15 , 2 ** 15 )
150- ),
151100 "lm1b_32k" : (
152101 lambda : lm1b .generator (FLAGS .tmp_dir , True ),
153102 lambda : lm1b .generator (FLAGS .tmp_dir , False )
@@ -286,6 +235,9 @@ def main(_):
286235 # Calculate the list of problems to generate.
287236 problems = sorted (
288237 list (_SUPPORTED_PROBLEM_GENERATORS ) + registry .list_problems ())
238+ for exclude in FLAGS .exclude_problems .split ("," ):
239+ if exclude :
240+ problems = [p for p in problems if exclude not in p ]
289241 if FLAGS .problem and FLAGS .problem [- 1 ] == "*" :
290242 problems = [p for p in problems if p .startswith (FLAGS .problem [:- 1 ])]
291243 elif FLAGS .problem :
0 commit comments