2828
2929import tensorflow as tf
3030
# End-of-sentence marker. Its value must stay in sync with the position of
# EOS inside the RESERVED_TOKENS list declared in text_encoder.py.
EOS = 1
3134
32- def character_generator (source_path , target_path , eos = None ):
35+
36+ def character_generator (source_path , target_path , character_vocab , eos = None ):
3337 """Generator for sequence-to-sequence tasks that just uses characters.
3438
3539 This generator assumes the files at source_path and target_path have
@@ -51,8 +55,8 @@ def character_generator(source_path, target_path, eos=None):
5155 with tf .gfile .GFile (target_path , mode = "r" ) as target_file :
5256 source , target = source_file .readline (), target_file .readline ()
5357 while source and target :
54- source_ints = [ ord ( c ) for c in source .strip ()] + eos_list
55- target_ints = [ ord ( c ) for c in target .strip ()] + eos_list
58+ source_ints = character_vocab . encode ( source .strip ()) + eos_list
59+ target_ints = character_vocab . encode ( target .strip ()) + eos_list
5660 yield {"inputs" : source_ints , "targets" : target_ints }
5761 source , target = source_file .readline (), target_file .readline ()
5862
@@ -226,14 +230,16 @@ def ende_wordpiece_token_generator(tmp_dir, train, vocab_size):
226230 tag = "train" if train else "dev"
227231 data_path = _compile_data (tmp_dir , datasets , "wmt_ende_tok_%s" % tag )
228232 return token_generator (data_path + ".lang1" , data_path + ".lang2" ,
229- symbolizer_vocab , 1 )
233+ symbolizer_vocab , EOS )
230234
231235
def ende_character_generator(tmp_dir, train):
  """Character-level generator for the WMT en->de translation task.

  Builds a byte-level vocabulary, compiles the train or dev datasets into
  a pair of parallel files, and yields encoded example dicts via
  character_generator.
  """
  character_vocab = text_encoder.ByteTextEncoder()
  if train:
    datasets = _ENDE_TRAIN_DATASETS
    tag = "train"
  else:
    datasets = _ENDE_TEST_DATASETS
    tag = "dev"
  data_path = _compile_data(tmp_dir, datasets, "wmt_ende_chr_%s" % tag)
  return character_generator(data_path + ".lang1", data_path + ".lang2",
                             character_vocab, EOS)
237243
238244
239245def enfr_wordpiece_token_generator (tmp_dir , train , vocab_size ):
@@ -244,22 +250,25 @@ def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size):
244250 tag = "train" if train else "dev"
245251 data_path = _compile_data (tmp_dir , datasets , "wmt_enfr_tok_%s" % tag )
246252 return token_generator (data_path + ".lang1" , data_path + ".lang2" ,
247- symbolizer_vocab , 1 )
253+ symbolizer_vocab , EOS )
248254
249255
def enfr_character_generator(tmp_dir, train):
  """Instance of character generator for the WMT en->fr task."""
  character_vocab = text_encoder.ByteTextEncoder()
  if train:
    datasets = _ENFR_TRAIN_DATASETS
    tag = "train"
  else:
    datasets = _ENFR_TEST_DATASETS
    tag = "dev"
  data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_chr_%s" % tag)
  return character_generator(data_path + ".lang1", data_path + ".lang2",
                             character_vocab, EOS)
256264
257265
def parsing_character_generator(tmp_dir, train):
  """Character-level generator for the parsing task.

  Reads the pre-tokenized "parsing_{train,dev}.text" / ".tags" file pair
  from tmp_dir and yields byte-encoded example dicts.
  """
  character_vocab = text_encoder.ByteTextEncoder()
  split = "train" if train else "dev"
  base_path = os.path.join(tmp_dir, "parsing_%s" % split)
  return character_generator(base_path + ".text", base_path + ".tags",
                             character_vocab, EOS)
263272
264273
265274def parsing_token_generator (tmp_dir , train , vocab_size ):
@@ -268,4 +277,4 @@ def parsing_token_generator(tmp_dir, train, vocab_size):
268277 filename = "parsing_%s" % ("train" if train else "dev" )
269278 text_filepath = os .path .join (tmp_dir , filename + ".text" )
270279 tags_filepath = os .path .join (tmp_dir , filename + ".tags" )
271- return token_generator (text_filepath , tags_filepath , symbolizer_vocab , 1 )
280+ return token_generator (text_filepath , tags_filepath , symbolizer_vocab , EOS )
0 commit comments