@@ -222,15 +222,8 @@ def _maybe_pack_examples(self, generator):
   def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
     generator = self.generate_samples(data_dir, tmp_dir, dataset_split)
     encoder = self.get_or_create_vocab(data_dir, tmp_dir)
-    for sample in generator:
-      targets = encoder.encode(sample["targets"])
-      targets.append(text_encoder.EOS_ID)
-      encoded_sample = {"targets": targets}
-      if self.has_inputs:
-        inputs = encoder.encode(sample["inputs"])
-        inputs.append(text_encoder.EOS_ID)
-        encoded_sample["inputs"] = inputs
-      yield encoded_sample
+    return text2text_generate_encoded(generator, encoder,
+                                      has_inputs=self.has_inputs)

   @property
   def batch_size_means_tokens(self):
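
Note: the inline encode-and-append-EOS loop is not deleted outright; it moves into the shared text2text_generate_encoded helper added at the bottom of this diff, so other callers can reuse it. A minimal sketch of what the refactored method now yields, assuming a hypothetical toy vocab in place of the one returned by get_or_create_vocab (text_encoder.EOS_ID is 1):

    class ToyVocab(object):
      """Hypothetical stand-in encoder: maps each token to its character length."""

      def encode(self, s):
        return [len(tok) for tok in s.split()]

    samples = iter([{"inputs": "hello world", "targets": "hi"}])
    for ex in text2text_generate_encoded(samples, ToyVocab(), has_inputs=True):
      print(ex)  # {'inputs': [5, 5, 1], 'targets': [2, 1]}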
@@ -244,15 +237,15 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
         problem.DatasetSplit.TEST: self.test_filepaths,
     }

-    split_paths = dict([(split["split"], filepath_fns[split["split"]](
+    split_paths = [(split["split"], filepath_fns[split["split"]](
         data_dir, split["shards"], shuffled=False))
-                        for split in self.dataset_splits])
+                   for split in self.dataset_splits]
     all_paths = []
-    for paths in split_paths.values():
+    for _, paths in split_paths:
       all_paths.extend(paths)

     if self.is_generate_per_split:
-      for split, paths in split_paths.items():
+      for split, paths in split_paths:
         generator_utils.generate_files(
             self._maybe_pack_examples(
                 self.generate_encoded_samples(data_dir, tmp_dir, split)), paths)
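
Replacing the dict with a list of (split, paths) tuples presumably makes iteration deterministic: splits are now processed in the order they are declared in dataset_splits, whereas dict iteration order is not guaranteed before Python 3.7. A small sketch of the resulting shape, with hypothetical shard filenames:

    split_paths = [("train", ["data-train-00000-of-00002",
                              "data-train-00001-of-00002"]),
                   ("dev", ["data-dev-00000-of-00001"])]
    all_paths = []
    for _, paths in split_paths:
      all_paths.extend(paths)
    # all_paths keeps declaration order: train shards first, then dev.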
@@ -418,8 +411,7 @@ def example_reading_spec(self):
 def txt_line_iterator(txt_path):
   """Iterate through lines of file."""
   with tf.gfile.Open(txt_path) as f:
-    readline = lambda: f.readline()
-    for line in iter(readline, ""):
+    for line in f:
       yield line.strip()


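Python file objects are already lazy line iterators, so the iter(readline, "") construction was redundant; for line in f reads one line at a time and stops at EOF. A sketch of the equivalent behavior, with a plain local file standing in for tf.gfile.Open:

    def txt_line_iterator_sketch(path):
      # Hypothetical local-file variant of txt_line_iterator.
      with open(path) as f:
        for line in f:  # lazy: one line per iteration, terminates at EOF
          yield line.strip()
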
@@ -472,11 +464,26 @@ def text2text_txt_tab_iterator(txt_path):
   """
   for line in txt_line_iterator(txt_path):
     if line and "\t" in line:
-      parts = line.split("\t")
+      parts = line.split("\t", 1)
       inputs, targets = parts[:2]
       yield {"inputs": inputs.strip(), "targets": targets.strip()}


+def text2text_generate_encoded(sample_generator,
+                               vocab,
+                               targets_vocab=None,
+                               has_inputs=True):
+  """Encode Text2Text samples from the generator with the vocab."""
+  targets_vocab = targets_vocab or vocab
+  for sample in sample_generator:
+    if has_inputs:
+      sample["inputs"] = vocab.encode(sample["inputs"])
+      sample["inputs"].append(text_encoder.EOS_ID)
+    sample["targets"] = targets_vocab.encode(sample["targets"])
+    sample["targets"].append(text_encoder.EOS_ID)
+    yield sample
+
+
 @registry.register_problem
 class Text2textTmpdir(Text2TextProblem):
   """Allows training a Text2TextProblem without defining a subclass.
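The new maxsplit=1 argument in text2text_txt_tab_iterator changes behavior for lines containing more than one tab: previously the target was truncated at the second tab (only parts[:2] is kept), whereas now everything after the first tab survives as the target. A quick illustration:

    line = "source text\ttarget\twith an embedded tab"
    inputs, targets = line.split("\t", 1)
    # inputs  == "source text"
    # targets == "target\twith an embedded tab"
    # line.split("\t")[:2] would silently drop "with an embedded tab".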