@@ -35,7 +35,6 @@ import tempfile
3535
3636import numpy as np
3737
38- from tensor2tensor .data_generators import algorithmic
3938from tensor2tensor .data_generators import algorithmic_math
4039from tensor2tensor .data_generators import all_problems # pylint: disable=unused-import
4140from tensor2tensor .data_generators import audio
@@ -60,6 +59,8 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen",
6059 "Temporary storage directory." )
6160flags .DEFINE_string ("problem" , "" ,
6261 "The name of the problem to generate data for." )
62+ flags .DEFINE_string ("exclude_problems" , "" ,
63+ "Comma-separated list of problems to exclude." )
6364flags .DEFINE_integer ("num_shards" , 10 , "How many shards to use." )
6465flags .DEFINE_integer ("max_cases" , 0 ,
6566 "Maximum number of cases to generate (unbounded if 0)." )
@@ -74,37 +75,6 @@ flags.DEFINE_string("t2t_usr_dir", "",
7475# Mapping from problems that we can generate data for to their generators.
7576# pylint: disable=g-long-lambda
7677_SUPPORTED_PROBLEM_GENERATORS = {
77- "algorithmic_shift_decimal40" : (
78- lambda : algorithmic .shift_generator (20 , 10 , 40 , 100000 ),
79- lambda : algorithmic .shift_generator (20 , 10 , 80 , 10000 )),
80- "algorithmic_reverse_binary40" : (
81- lambda : algorithmic .reverse_generator (2 , 40 , 100000 ),
82- lambda : algorithmic .reverse_generator (2 , 400 , 10000 )),
83- "algorithmic_reverse_decimal40" : (
84- lambda : algorithmic .reverse_generator (10 , 40 , 100000 ),
85- lambda : algorithmic .reverse_generator (10 , 400 , 10000 )),
86- "algorithmic_addition_binary40" : (
87- lambda : algorithmic .addition_generator (2 , 40 , 100000 ),
88- lambda : algorithmic .addition_generator (2 , 400 , 10000 )),
89- "algorithmic_addition_decimal40" : (
90- lambda : algorithmic .addition_generator (10 , 40 , 100000 ),
91- lambda : algorithmic .addition_generator (10 , 400 , 10000 )),
92- "algorithmic_multiplication_binary40" : (
93- lambda : algorithmic .multiplication_generator (2 , 40 , 100000 ),
94- lambda : algorithmic .multiplication_generator (2 , 400 , 10000 )),
95- "algorithmic_multiplication_decimal40" : (
96- lambda : algorithmic .multiplication_generator (10 , 40 , 100000 ),
97- lambda : algorithmic .multiplication_generator (10 , 400 , 10000 )),
98- "algorithmic_reverse_nlplike_decimal8K" : (
99- lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 100000 ,
100- 10 , 1.300 ),
101- lambda : algorithmic .reverse_generator_nlplike (8000 , 70 , 10000 ,
102- 10 , 1.300 )),
103- "algorithmic_reverse_nlplike_decimal32K" : (
104- lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 100000 ,
105- 10 , 1.050 ),
106- lambda : algorithmic .reverse_generator_nlplike (32000 , 70 , 10000 ,
107- 10 , 1.050 )),
10878 "algorithmic_algebra_inverse" : (
10979 lambda : algorithmic_math .algebra_inverse (26 , 0 , 2 , 100000 ),
11080 lambda : algorithmic_math .algebra_inverse (26 , 3 , 3 , 10000 )),
@@ -124,29 +94,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
12494 2 ** 14 , 2 ** 9 ),
12595 lambda : wsj_parsing .parsing_token_generator (FLAGS .tmp_dir , False ,
12696 2 ** 14 , 2 ** 9 )),
127- "wmt_enfr_characters" : (
128- lambda : wmt .enfr_character_generator (FLAGS .tmp_dir , True ),
129- lambda : wmt .enfr_character_generator (FLAGS .tmp_dir , False )),
130- "wmt_enfr_tokens_8k" : (
131- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 13 ),
132- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 13 )
133- ),
134- "wmt_enfr_tokens_32k" : (
135- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , True , 2 ** 15 ),
136- lambda : wmt .enfr_wordpiece_token_generator (FLAGS .tmp_dir , False , 2 ** 15 )
137- ),
138- "wmt_ende_characters" : (
139- lambda : wmt .ende_character_generator (FLAGS .tmp_dir , True ),
140- lambda : wmt .ende_character_generator (FLAGS .tmp_dir , False )),
14197 "wmt_ende_bpe32k" : (
14298 lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , True ),
14399 lambda : wmt .ende_bpe_token_generator (FLAGS .tmp_dir , False )),
144- "wmt_zhen_tokens_32k" : (
145- lambda : wmt .zhen_wordpiece_token_generator (FLAGS .tmp_dir , True ,
146- 2 ** 15 , 2 ** 15 ),
147- lambda : wmt .zhen_wordpiece_token_generator (FLAGS .tmp_dir , False ,
148- 2 ** 15 , 2 ** 15 )
149- ),
150100 "lm1b_32k" : (
151101 lambda : lm1b .generator (FLAGS .tmp_dir , True ),
152102 lambda : lm1b .generator (FLAGS .tmp_dir , False )
@@ -285,6 +235,9 @@ def main(_):
285235 # Calculate the list of problems to generate.
286236 problems = sorted (
287237 list (_SUPPORTED_PROBLEM_GENERATORS ) + registry .list_problems ())
238+ for exclude in FLAGS .exclude_problems .split ("," ):
239+ if exclude :
240+ problems = [p for p in problems if exclude not in p ]
288241 if FLAGS .problem and FLAGS .problem [- 1 ] == "*" :
289242 problems = [p for p in problems if p .startswith (FLAGS .problem [:- 1 ])]
290243 elif FLAGS .problem :
@@ -364,7 +317,7 @@ def generate_data_for_problem(problem):
364317
365318def generate_data_for_registered_problem (problem_name ):
366319 problem = registry .problem (problem_name )
367- problem .generate_data (FLAGS .data_dir , FLAGS .tmp_dir )
320+ problem .generate_data (FLAGS .data_dir , FLAGS .tmp_dir , FLAGS . num_shards )
368321
369322
370323if __name__ == "__main__" :
0 commit comments