11#!/usr/bin/env python
2+ # coding=utf-8
23# Copyright 2017 The Tensor2Tensor Authors.
34#
45# Licensed under the Apache License, Version 2.0 (the "License");
@@ -62,10 +63,12 @@ flags.DEFINE_string("problem", "",
6263 "The name of the problem to generate data for." )
flags.DEFINE_string("exclude_problems", "",
                    "Comma-separated list of problems to exclude.")
# 0 means "not set". Generators that take a shard count fall back to 10 when
# this is 0; the registered-Problem path raises ValueError if it is non-zero,
# so the help text says "must be left at 0" rather than "ignored".
flags.DEFINE_integer("num_shards", 0,
                     "How many shards to use. Must be left at 0 for "
                     "registered Problems.")
flags.DEFINE_integer("max_cases", 0,
                     "Maximum number of cases to generate (unbounded if 0).")
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
# A negative value means no task id is forwarded to the Problem.
flags.DEFINE_integer("task_id", -1, "For distributed data generation.")
6972flags .DEFINE_string ("t2t_usr_dir" , "" ,
7073 "Path to a Python module that will be imported. The "
7174 "__init__.py file should include the necessary imports. "
@@ -108,6 +111,10 @@ _SUPPORTED_PROBLEM_GENERATORS = {
108111 lambda : lm1b .generator (FLAGS .tmp_dir , True ),
109112 lambda : lm1b .generator (FLAGS .tmp_dir , False )
110113 ),
114+ "lm1b_characters" : (
115+ lambda : lm1b .generator (FLAGS .tmp_dir , True , characters = True ),
116+ lambda : lm1b .generator (FLAGS .tmp_dir , False , characters = True )
117+ ),
111118 "wiki_32k" : (
112119 lambda : wiki .generator (FLAGS .tmp_dir , True ),
113120 1000
@@ -246,7 +253,7 @@ def generate_data_for_problem(problem):
246253 if isinstance (dev_gen , int ):
247254 # The dev set and test sets are generated as extra shards using the
248255 # training generator. The integer specifies the number of training
249- # shards. FLAGS.num_shards is ignored.
256+ # shards. FLAGS.num_shards is ignored.
250257 num_training_shards = dev_gen
251258 tf .logging .info ("Generating data for %s." , problem )
252259 all_output_files = generator_utils .combined_data_filenames (
@@ -257,10 +264,11 @@ def generate_data_for_problem(problem):
257264 else :
258265 # usual case - train data and dev data are generated using separate
259266 # generators.
267+ num_shards = FLAGS .num_shards or 10
260268 tf .logging .info ("Generating training data for %s." , problem )
261269 train_output_files = generator_utils .train_data_filenames (
262270 problem + generator_utils .UNSHUFFLED_SUFFIX , FLAGS .data_dir ,
263- FLAGS . num_shards )
271+ num_shards )
264272 generator_utils .generate_files (training_gen (), train_output_files ,
265273 FLAGS .max_cases )
266274 tf .logging .info ("Generating development data for %s." , problem )
@@ -275,10 +283,14 @@ def generate_data_for_problem(problem):
275283
276284
def generate_data_for_registered_problem(problem_name):
  """Generate data for a Problem looked up in the registry.

  Args:
    problem_name: name passed to `registry.problem` to look up the Problem.

  Raises:
    ValueError: if --num_shards is set to a non-zero value; the shard count
      is no longer passed to the Problem from here.
  """
  tf.logging.info("Generating training data for %s.", problem_name)
  if FLAGS.num_shards:
    raise ValueError("--num_shards should not be set for registered Problem.")
  problem = registry.problem(problem_name)
  # A negative --task_id (the flag default is -1) is forwarded as None.
  task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
  problem.generate_data(os.path.expanduser(FLAGS.data_dir),
                        os.path.expanduser(FLAGS.tmp_dir),
                        task_id=task_id)
282294
283295
284296if __name__ == "__main__" :
0 commit comments