@@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
2424yields for each training example a dictionary mapping string feature names to
2525lists of {string, int, float}. The generator will be run once for each mode.
2626"""
27+ from __future__ import absolute_import
28+ from __future__ import division
29+ from __future__ import print_function
2730
2831import random
2932import tempfile
@@ -34,6 +37,7 @@ import numpy as np
3437
3538from tensor2tensor .data_generators import algorithmic
3639from tensor2tensor .data_generators import algorithmic_math
40+ from tensor2tensor .data_generators import all_problems # pylint: disable=unused-import
3741from tensor2tensor .data_generators import audio
3842from tensor2tensor .data_generators import generator_utils
3943from tensor2tensor .data_generators import image
@@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli
4347from tensor2tensor .data_generators import wiki
4448from tensor2tensor .data_generators import wmt
4549from tensor2tensor .data_generators import wsj_parsing
50+ from tensor2tensor .utils import registry
4651
4752import tensorflow as tf
4853
@@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
6267# Mapping from problems that we can generate data for to their generators.
6368# pylint: disable=g-long-lambda
6469_SUPPORTED_PROBLEM_GENERATORS = {
65- "algorithmic_identity_binary40" : (
66- lambda : algorithmic .identity_generator (2 , 40 , 100000 ),
67- lambda : algorithmic .identity_generator (2 , 400 , 10000 )),
68- "algorithmic_identity_decimal40" : (
69- lambda : algorithmic .identity_generator (10 , 40 , 100000 ),
70- lambda : algorithmic .identity_generator (10 , 400 , 10000 )),
7170 "algorithmic_shift_decimal40" : (
7271 lambda : algorithmic .shift_generator (20 , 10 , 40 , 100000 ),
7372 lambda : algorithmic .shift_generator (20 , 10 , 80 , 10000 )),
@@ -294,8 +293,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
294293
295294# pylint: enable=g-long-lambda
296295
297- UNSHUFFLED_SUFFIX = "-unshuffled"
298-
299296
300297def set_random_seed ():
301298 """Set the random seed from flag everywhere."""
@@ -308,13 +305,15 @@ def main(_):
308305 tf .logging .set_verbosity (tf .logging .INFO )
309306
310307 # Calculate the list of problems to generate.
311- problems = list (sorted (_SUPPORTED_PROBLEM_GENERATORS ))
308+ problems = sorted (
309+ list (_SUPPORTED_PROBLEM_GENERATORS ) + registry .list_problems ())
312310 if FLAGS .problem and FLAGS .problem [- 1 ] == "*" :
313311 problems = [p for p in problems if p .startswith (FLAGS .problem [:- 1 ])]
314312 elif FLAGS .problem :
315313 problems = [p for p in problems if p == FLAGS .problem ]
316314 else :
317315 problems = []
316+
318317 # Remove TIMIT if paths are not given.
319318 if not FLAGS .timit_paths :
320319 problems = [p for p in problems if "timit" not in p ]
@@ -326,7 +325,8 @@ def main(_):
326325 problems = [p for p in problems if "ende_bpe" not in p ]
327326
328327 if not problems :
329- problems_str = "\n * " .join (sorted (_SUPPORTED_PROBLEM_GENERATORS ))
328+ problems_str = "\n * " .join (
329+ sorted (list (_SUPPORTED_PROBLEM_GENERATORS ) + registry .list_problems ()))
330330 error_msg = ("You must specify one of the supported problems to "
331331 "generate data for:\n * " + problems_str + "\n " )
332332 error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
@@ -343,40 +343,50 @@ def main(_):
343343 for problem in problems :
344344 set_random_seed ()
345345
346- training_gen , dev_gen = _SUPPORTED_PROBLEM_GENERATORS [problem ]
347-
348- if isinstance (dev_gen , int ):
349- # The dev set and test sets are generated as extra shards using the
350- # training generator. The integer specifies the number of training
351- # shards. FLAGS.num_shards is ignored.
352- num_training_shards = dev_gen
353- tf .logging .info ("Generating data for %s." , problem )
354- all_output_files = generator_utils .combined_data_filenames (
355- problem + UNSHUFFLED_SUFFIX , FLAGS .data_dir , num_training_shards )
356- generator_utils .generate_files (
357- training_gen (), all_output_files , FLAGS .max_cases )
346+ if problem in _SUPPORTED_PROBLEM_GENERATORS :
347+ generate_data_for_problem (problem )
358348 else :
359- # usual case - train data and dev data are generated using separate
360- # generators.
361- tf .logging .info ("Generating training data for %s." , problem )
362- train_output_files = generator_utils .train_data_filenames (
363- problem + UNSHUFFLED_SUFFIX , FLAGS .data_dir , FLAGS .num_shards )
364- generator_utils .generate_files (
365- training_gen (), train_output_files , FLAGS .max_cases )
366- tf .logging .info ("Generating development data for %s." , problem )
367- dev_shards = 10 if "coco" in problem else 1
368- dev_output_files = generator_utils .dev_data_filenames (
369- problem + UNSHUFFLED_SUFFIX , FLAGS .data_dir , dev_shards )
370- generator_utils .generate_files (dev_gen (), dev_output_files )
371- all_output_files = train_output_files + dev_output_files
349+ generate_data_for_registered_problem (problem )
350+
351+
def generate_data_for_problem(problem):
  """Generates, then shuffles, data files for one hard-coded problem.

  Looks `problem` up in `_SUPPORTED_PROBLEM_GENERATORS`, writes the unshuffled
  output files under `FLAGS.data_dir`, and finally hands every produced file
  to `generator_utils.shuffle_dataset`.

  Args:
    problem: key into `_SUPPORTED_PROBLEM_GENERATORS`. The mapped value is a
      pair whose first item is a callable returning the training generator and
      whose second item is either a callable returning the dev generator or an
      int.
  """
  make_train_gen, dev_spec = _SUPPORTED_PROBLEM_GENERATORS[problem]

  if not isinstance(dev_spec, int):
    # Usual case: training and dev data come from two separate generators.
    tf.logging.info("Generating training data for %s.", problem)
    unshuffled_prefix = problem + generator_utils.UNSHUFFLED_SUFFIX
    train_files = generator_utils.train_data_filenames(
        unshuffled_prefix, FLAGS.data_dir, FLAGS.num_shards)
    generator_utils.generate_files(
        make_train_gen(), train_files, FLAGS.max_cases)

    tf.logging.info("Generating development data for %s.", problem)
    # Problems with "coco" in their name get 10 dev shards, all others 1.
    if "coco" in problem:
      num_dev_shards = 10
    else:
      num_dev_shards = 1
    dev_files = generator_utils.dev_data_filenames(
        unshuffled_prefix, FLAGS.data_dir, num_dev_shards)
    # Note: FLAGS.max_cases is only applied to the training data.
    generator_utils.generate_files(dev_spec(), dev_files)

    files_to_shuffle = train_files + dev_files
  else:
    # An int in the second slot means the dev and test sets are generated as
    # extra shards by the training generator. The int is the number of
    # training shards; FLAGS.num_shards is ignored.
    shard_count = dev_spec
    tf.logging.info("Generating data for %s.", problem)
    unshuffled_prefix = problem + generator_utils.UNSHUFFLED_SUFFIX
    files_to_shuffle = generator_utils.combined_data_filenames(
        unshuffled_prefix, FLAGS.data_dir, shard_count)
    generator_utils.generate_files(
        make_train_gen(), files_to_shuffle, FLAGS.max_cases)

  tf.logging.info("Shuffling data...")
  generator_utils.shuffle_dataset(files_to_shuffle)
385+
372386
373- tf .logging .info ("Shuffling data..." )
374- for fname in all_output_files :
375- records = generator_utils .read_records (fname )
376- random .shuffle (records )
377- out_fname = fname .replace (UNSHUFFLED_SUFFIX , "" )
378- generator_utils .write_records (records , out_fname )
379- tf .gfile .Remove (fname )
def generate_data_for_registered_problem(problem_name):
  """Generates data for a problem obtained from the registry.

  Args:
    problem_name: name passed to `registry.problem`; the returned object's
      `generate_data` method is called with `FLAGS.data_dir`.
  """
  registry.problem(problem_name).generate_data(FLAGS.data_dir)
380390
381391
382392if __name__ == "__main__" :
0 commit comments