@@ -102,6 +102,7 @@ def default_model_hparams():
       max_input_seq_length=0,
       max_target_seq_length=0,
       prepend_mode="none",
+      split_to_length=0,
       data_dir=None)


@@ -117,6 +118,12 @@ def preprocess_example_common(example, hparams, mode):
     else:
       example["targets"] = tf.concat(
           [example["inputs"], [0], example["targets"]], 0)
+  if hparams.split_to_length:
+    example["targets"] = tf.reshape(
+        example["targets"], [-1, hparams.split_to_length, 1, 1])
+    if len(example) != 1:
+      raise ValueError("split_to_length only works for LM problems")
+    return tf.data.Dataset.from_tensor_slices(example)
   return example


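
A minimal standalone sketch of what the split_to_length branch above does, using hypothetical values (not part of the diff): one long "targets" vector is reshaped into chunks of split_to_length and then sliced into separate examples. The reshape only succeeds when the targets length is a multiple of split_to_length.

import tensorflow as tf

split_to_length = 4  # hypothetical value of hparams.split_to_length
example = {"targets": tf.constant([11, 12, 13, 14, 15, 16, 17, 18])}

# Reshape the single length-8 vector into two chunks of length 4; the trailing
# dims of 1 match the usual [length, 1, 1] feature layout.
example["targets"] = tf.reshape(
    example["targets"], [-1, split_to_length, 1, 1])

# from_tensor_slices turns the leading dimension into separate dataset
# elements, so one input example becomes two examples of length 4.
dataset = tf.data.Dataset.from_tensor_slices(example)
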
@@ -232,7 +239,29 @@ def max_length(self, model_hparams):
     Returns:
       an integer
     """
-    return model_hparams.max_length or model_hparams.batch_size
+    return (
+        model_hparams.split_to_length or
+        model_hparams.max_length or
+        model_hparams.batch_size)
+
+  @property
+  def batch_size_means_tokens(self):
+    """Do we specify hparams.batch_size in tokens per datashard per batch.
+
+    This is generally done for text problems.
+
+    If False, we assume that batch sizes are specified in examples per
+    datashard per batch.
+
+    TODO(noam): we should be more explicit and replace the hyperparameter
+    batch size with two hyperparameters:
+      hparams.examples_per_batch_per_datashard
+      hparams.tokens_per_batch_per_datashard
+
+    Returns:
+      a boolean
+    """
+    return False

   def dataset_filename(self):
     return self.name
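
A usage sketch for the new property (the subclass below is hypothetical, not part of this change): a problem whose examples are variable-length token sequences opts into the token-based reading of hparams.batch_size by overriding batch_size_means_tokens.

class MyTokenLevelProblem(Problem):  # hypothetical subclass for illustration
  """Examples are variable-length, so hparams.batch_size counts tokens."""

  @property
  def batch_size_means_tokens(self):
    return True
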
@@ -620,23 +649,39 @@ def define_shapes(example):
     if is_training:
       dataset = dataset.repeat(None)

+    if self.batch_size_means_tokens:
+      batch_size_means_tokens = True
+    else:
+      if _are_shapes_fully_defined(dataset.output_shapes):
+        batch_size_means_tokens = False
+      else:
+        tf.logging.warning(
+            "Shapes are not fully defined. Assuming batch_size means tokens. "
+            "You should probably override batch_size_means_tokens() "
+            "in your problem subclass")
+        batch_size_means_tokens = True
+
     # Batching
-    if _are_shapes_fully_defined(dataset.output_shapes):
-      # Static shape features (e.g. images)
+    if not batch_size_means_tokens:
+      # Batch size means examples per datashard.
       if config and config.use_tpu:
+        # on TPU, we use params["batch_size"], which specifies the number of
+        # examples across all datashards
         tpu_batch_size = params["batch_size"]
         dataset = dataset.apply(
             tf.contrib.data.batch_and_drop_remainder(tpu_batch_size))
       else:
         num_shards = (config and config.data_parallelism.n) or 1
         dataset = dataset.batch(hparams.batch_size * num_shards)
     else:
-      # Variable length features
+      # batch_size means tokens per datashard
       if config and config.use_tpu:
         # On TPU, pad to max_length
         dataset = dataset.filter(tpu_valid_size)
         padded_shapes = _fill_shape_nones(
             dataset.output_shapes, none_filler=max_length)
+        # on TPU, we use params["batch_size"], which specifies the number of
+        # examples across all datashards
         dataset = dataset.apply(
             tf.contrib.data.padded_batch_and_drop_remainder(
                 params["batch_size"], padded_shapes))
@@ -648,6 +693,7 @@ def define_shapes(example):
             shard_multiplier=(config and config.data_parallelism.n) or 1,
             length_multiplier=self.get_hparams().batch_size_multiplier)
         if hparams.use_fixed_batch_size:
+          # Here batch_size really means examples per datashard.
           batching_scheme["batch_sizes"] = [hparams.batch_size]
           batching_scheme["boundaries"] = []
         dataset = data_reader.bucket_by_sequence_length(
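
The token-based TPU branch above pads every example to a static max_length and drops the incomplete final batch. The sketch below reproduces that behavior on hypothetical in-memory data, using tf.data's padded_batch with drop_remainder=True as a stand-in for tf.contrib.data.padded_batch_and_drop_remainder.

import tensorflow as tf

def toy_examples():
  # Hypothetical variable-length language-modeling examples.
  yield {"targets": [11, 12, 13]}
  yield {"targets": [14, 15]}
  yield {"targets": [16, 17, 18, 19]}

dataset = tf.data.Dataset.from_generator(
    toy_examples,
    output_types={"targets": tf.int64},
    output_shapes={"targets": [None]})

max_length = 8   # stands in for max_length(model_hparams)
batch_size = 2   # stands in for params["batch_size"]

# Pad each example to max_length and drop the short final batch so every
# batch has a fully static shape, as TPUs require.
dataset = dataset.padded_batch(
    batch_size,
    padded_shapes={"targets": [max_length]},
    drop_remainder=True)
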
@@ -818,6 +864,10 @@ def is_character_level(self):
   def targeted_vocab_size(self):
     raise NotImplementedError()  # Not needed if self.is_character_level.

+  @property
+  def batch_size_means_tokens(self):
+    return True
+
   def generator(self, data_dir, tmp_dir, is_training):
     """Generator for the training and evaluation data.

@@ -981,14 +1031,14 @@ class ChoppedTextProblem(Text2TextProblem):
   """Tokenize and chop text files into fixed-length language-modeling examples.

   The input data is a set of text files, as specified by
-  self.train_text_filenames() and self.dev_text_filenames().
+  self.train_text_filepaths() and self.dev_text_filepaths().

   The text is tokenized using a SubwordTextEncoder, and
   then split into examples, each of length self.sequence_length().
   """

-  def train_text_filenames(self, tmp_dir):
-    """Local filenames of text files containing training data.
+  def train_text_filepaths(self, tmp_dir):
+    """Local filepaths of text files containing training data.

     This function may want to download the files if they do not exist.

@@ -999,8 +1049,8 @@ def train_text_filenames(self, tmp_dir):
     """
     raise NotImplementedError()

-  def dev_text_filenames(self, tmp_dir):
-    """Local filenames of text files containing dev data.
+  def dev_text_filepaths(self, tmp_dir):
+    """Local filepaths of text files containing dev data.

     This function may want to download the files if they do not exist.

@@ -1016,15 +1066,15 @@ def sequence_length(self):
     """Length of each example (in tokens)."""
     raise NotImplementedError()

-  def max_length(self, unused_model_hparams):
-    return self.sequence_length
+  def max_length(self, model_hparams):
+    return model_hparams.split_to_length or self.sequence_length

   @property
   def is_character_level(self):
     return False

-  def text_filenames_for_task(self, tmp_dir, task_id):
-    """List of input filenames for a particular training or dev shard.
+  def text_filepaths_for_task(self, tmp_dir, task_id):
+    """List of input filepaths for a particular training or dev shard.

     Args:
       tmp_dir: a string
@@ -1035,49 +1085,69 @@ def text_filenames_for_task(self, tmp_dir, task_id):
     assert task_id >= 0
     assert task_id < self.num_train_shards + self.num_dev_shards
     if task_id < self.num_train_shards:
-      return [f for i, f in enumerate(self.train_text_filenames(tmp_dir))
+      return [f for i, f in enumerate(self.train_text_filepaths(tmp_dir))
               if i % self.num_train_shards == task_id]
     else:
-      return [f for i, f in enumerate(self.dev_text_filenames(tmp_dir))
+      return [f for i, f in enumerate(self.dev_text_filepaths(tmp_dir))
               if i % self.num_dev_shards == task_id - self.num_train_shards]

-  def filename_to_unicode_text(self, filename):
+  def filepath_to_unicode_strings(self, filepath):
     """Read text out of an input file.

-    The default just reads the text and converts to unicode.
+    The default just reads the text, converts to unicode and yields one
+    unicode string.

-    Subclasses can override this function in order to preprocess.
+    Subclasses can override this function in order to preprocess, and can
+    yield any number of strings.

     Args:
-      filename: a string
-    Returns:
-      a unicode string.
+      filepath: a string
+    Yields:
+      unicode strings.
     """
-    f = tf.gfile.Open(filename)
+    f = tf.gfile.Open(filepath)
     b = f.read()
-    return to_unicode_ignore_erros(b)
+    yield to_unicode_ignore_erros(b)
+
+  def file_generator(self,
+                     filepaths,
+                     max_chars_per_file=None,
+                     max_chars_total=None):
+    """Read complete text of input files and yield unicode strings.
+
+    By default, one unicode string is produced per file, but this is
+    not guaranteed, since subclasses can override
+    filepath_to_unicode_strings().

-  def file_generator(self, tmp_dir, task_id, max_files=None):
-    """Reads complete text of input files and returns as unicode.
+    max_chars_per_file and max_chars_total can also be specified, in which
+    case some strings may be truncated or dropped to limit the total
+    amount of output.

     Args:
-      tmp_dir: a string
-      task_id: an integer less than num_shards, or "train" for training shards
-      max_files: an optional integer
+      filepaths: a list of strings
+      max_chars_per_file: an optional integer
+      max_chars_total: an optional integer
     Yields:
       unicode strings
     """
-    count = 0
-    if task_id == "train":
-      fnames = self.train_text_filenames(tmp_dir)
-    else:
-      fnames = self.text_filenames_for_task(tmp_dir, task_id)
-    for fname in fnames:
+    chars_total = 0
+    for fname in filepaths:
+      chars_this_file = 0
       tf.logging.info("reading file %s" % fname)
-      yield self.filename_to_unicode_text(fname)
-      count += 1
-      if max_files and count == max_files:
-        return
+      for text in self.filepath_to_unicode_strings(fname):
+        if (max_chars_per_file and chars_this_file + len(text)
+            > max_chars_per_file):
+          text = text[:max_chars_per_file - chars_this_file]
+        if max_chars_total and chars_total + len(text) > max_chars_total:
+          text = text[:max_chars_total - chars_total]
+        chars_total += len(text)
+        chars_this_file += len(text)
+        if text:
+          yield text
+        if max_chars_per_file and chars_this_file >= max_chars_per_file:
+          break
+      if max_chars_total and chars_total >= max_chars_total:
+        break

   def example_generator(self, encoder, tmp_dir, task_id):
     """Generator for examples.
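
The character-limiting rules in file_generator are easier to follow outside diff form. The helper below (hypothetical, with in-memory strings instead of files) applies the same truncation logic: per-file output is capped at max_chars_per_file, overall output at max_chars_total, and empty leftovers are dropped.

def limited_strings(strings_per_file, max_chars_per_file=None,
                    max_chars_total=None):
  """Yield strings with the same per-file and total caps as file_generator."""
  chars_total = 0
  for strings in strings_per_file:  # one inner list per "file"
    chars_this_file = 0
    for text in strings:
      if (max_chars_per_file and
          chars_this_file + len(text) > max_chars_per_file):
        text = text[:max_chars_per_file - chars_this_file]
      if max_chars_total and chars_total + len(text) > max_chars_total:
        text = text[:max_chars_total - chars_total]
      chars_total += len(text)
      chars_this_file += len(text)
      if text:
        yield text
      if max_chars_per_file and chars_this_file >= max_chars_per_file:
        break
    if max_chars_total and chars_total >= max_chars_total:
      break

# Two "files" of 6 characters each, capped at 4 chars per file and 7 overall,
# yield "abcd" from the first file and "ghi" from the second.
print(list(limited_strings([["abcdef"], ["ghijkl"]],
                           max_chars_per_file=4, max_chars_total=7)))
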
@@ -1089,17 +1159,29 @@ def example_generator(self, encoder, tmp_dir, task_id):
     Yields:
       feature dictionaries
     """
-    for ftext in self.file_generator(tmp_dir, task_id):
-      encoded = encoder.encode(ftext)
-      for start_pos in xrange(0, len(encoded), self.sequence_length):
-        targets = encoded[start_pos:start_pos + self.sequence_length]
-        if len(targets) < self.sequence_length:
-          if self.remainder_policy == "pad":
-            targets += [0] * (self.sequence_length - len(targets))
-          else:
-            assert self.remainder_policy == "drop"
-            continue
+    filepaths = self.text_filepaths_for_task(tmp_dir, task_id)
+    if task_id >= self.num_train_shards:
+      # this is dev data - limit the total length.
+      max_chars_per_file = self.max_dev_chars // (
+          self.num_dev_shards * len(filepaths))
+    else:
+      max_chars_per_file = None
+    tokens = []
+    for ftext in self.file_generator(
+        filepaths, max_chars_per_file=max_chars_per_file):
+      tokens.extend(encoder.encode(ftext))
+      pos = 0
+      while pos + self.sequence_length <= len(tokens):
+        yield {"inputs": [0], "targets": tokens[pos:pos + self.sequence_length]}
+        pos += self.sequence_length
+      if pos > 0:
+        tokens = tokens[pos:]
+    if self.remainder_policy == "pad":
+      if tokens:
+        targets = tokens + [0] * (self.sequence_length - len(tokens))
         yield {"inputs": [0], "targets": targets}
+    else:
+      assert self.remainder_policy == "drop"

   @property
   def remainder_policy(self):
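
A simplified standalone sketch of the new chopping behavior in example_generator (names and data below are hypothetical): encoded tokens from all files in a shard are concatenated and emitted as fixed-length chunks, and the final partial chunk is padded with zeros or dropped according to remainder_policy.

def chop_into_examples(token_streams, sequence_length, remainder_policy="pad"):
  """Yield {"inputs": [0], "targets": ...} dicts of exactly sequence_length tokens."""
  tokens = []
  for stream in token_streams:  # e.g. encoder.encode(text) for each file
    tokens.extend(stream)
    pos = 0
    while pos + sequence_length <= len(tokens):
      yield {"inputs": [0], "targets": tokens[pos:pos + sequence_length]}
      pos += sequence_length
    tokens = tokens[pos:]  # carry the remainder over to the next file
  if remainder_policy == "pad":
    if tokens:
      yield {"inputs": [0],
             "targets": tokens + [0] * (sequence_length - len(tokens))}
  else:
    assert remainder_policy == "drop"

# Seven tokens chopped into length-3 examples: [1, 2, 3], [4, 5, 6],
# and a padded [7, 0, 0].
examples = list(chop_into_examples([[1, 2, 3, 4], [5, 6, 7]], 3))
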
@@ -1113,14 +1195,15 @@ def remainder_policy(self):
   def prepare_to_generate(self, data_dir, tmp_dir):
     """Make sure that the data is prepared and the vocab is generated."""
     self.get_or_generate_vocab(data_dir, tmp_dir)
-    self.train_text_filenames(tmp_dir)
-    self.dev_text_filenames(tmp_dir)
+    self.train_text_filepaths(tmp_dir)
+    self.dev_text_filepaths(tmp_dir)

   def get_or_generate_vocab(self, data_dir, tmp_dir):
     return generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
         self.file_generator(
-            tmp_dir, task_id="train", max_files=self.max_files_for_vocab))
+            self.train_text_filepaths(tmp_dir),
+            max_chars_total=self.max_chars_for_vocab))

   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     """Generates training/dev data.
@@ -1147,9 +1230,9 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
     generator_utils.shuffle_dataset([out_file])

   @property
-  def max_files_for_vocab(self):
-    """Number of input files to read when generating vocab."""
-    return 10
+  def max_chars_for_vocab(self):
+    """Number of characters of training data to use for generating vocab."""
+    return 10**7

   @property
   def target_space_id(self):
@@ -1163,6 +1246,11 @@ def num_train_shards(self):
   def num_dev_shards(self):
     return 1

+  @property
+  def max_dev_chars(self):
+    """Limit dev set to at most this many characters (default 10M)."""
+    return 10**7
+
   @property
   def multiprocess_generate(self):
     return True
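
Putting the renamed API together, a concrete subclass would look roughly like the sketch below. Everything here (class name, file paths, constants) is hypothetical and only illustrates which methods and properties a ChoppedTextProblem now expects.

import os

class MyChoppedTextProblem(ChoppedTextProblem):  # hypothetical example
  """Chop a local text corpus into length-256 language-modeling examples."""

  def train_text_filepaths(self, tmp_dir):
    # Illustrative path; a real problem might download its corpus here.
    return [os.path.join(tmp_dir, "corpus_train.txt")]

  def dev_text_filepaths(self, tmp_dir):
    return [os.path.join(tmp_dir, "corpus_dev.txt")]

  @property
  def sequence_length(self):
    return 256

  @property
  def targeted_vocab_size(self):
    return 2**13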