 # Dependency imports
 
 from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
-
 tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir."
                        "Download from https://drive.google.com/open?"
                        "id=0B_bZck-ksdkpM25jRUN2X2UxMm8")
 
-
 FLAGS = tf.flags.FLAGS
 
 
+@registry.register_problem("wmt_ende_tokens_8k")
+class WMTEnDeTokens8k(problem.Problem):
+  """Problem spec for WMT En-De translation."""
+
+  @property
+  def target_vocab_size(self):
+    return 2**13  # 8192
+
+  def feature_encoders(self, data_dir):
+    return _default_wmt_feature_encoders(data_dir, self.target_vocab_size)
+
+  def generate_data(self, data_dir, tmp_dir):
+    generator_utils.generate_dataset_and_shuffle(
+        ende_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size),
+        self.training_filepaths(data_dir, 100, shuffled=False),
+        ende_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size),
+        self.dev_filepaths(data_dir, 1, shuffled=False))
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    vocab_size = self._encoders["inputs"].vocab_size
+    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
+    p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
+    p.input_space_id = problem.SpaceID.EN_TOK
+    p.target_space_id = problem.SpaceID.DE_TOK
+
+
+@registry.register_problem("wmt_ende_tokens_32k")
+class WMTEnDeTokens32k(WMTEnDeTokens8k):
+
+  @property
+  def target_vocab_size(self):
+    return 2**15  # 32768
+
+
+def _default_wmt_feature_encoders(data_dir, target_vocab_size):
+  vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size)
+  subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
+  return {
+      "inputs": subtokenizer,
+      "targets": subtokenizer,
+  }
+
+
 # End-of-sentence marker.
 EOS = text_encoder.EOS_TOKEN
 
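The hunk above only registers new problem specs; it does not change any generator logic. A rough sketch of how one of these registered problems might be driven end to end follows. The registry.problem() lookup and the way the returned object is used are assumptions about the registry API, not something this diff shows.

from tensor2tensor.utils import registry


def prepare_wmt_ende(data_dir, tmp_dir):
  # Assumed lookup: resolves the name registered above via
  # @registry.register_problem("wmt_ende_tokens_8k").
  ende = registry.problem("wmt_ende_tokens_8k")
  # generate_data() writes shuffled train/dev files into data_dir using the
  # En-De wordpiece generators defined later in this module.
  ende.generate_data(data_dir, tmp_dir)
  # feature_encoders() rebuilds the subword encoders from the vocab file
  # that generation leaves in data_dir.
  encoders = ende.feature_encoders(data_dir)
  return encoders["inputs"].vocab_size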
@@ -130,7 +174,8 @@ def token_generator(source_path, target_path, token_vocab, eos=None):
         source, target = source_file.readline(), target_file.readline()
 
 
-def bi_vocabs_token_generator(source_path, target_path,
+def bi_vocabs_token_generator(source_path,
+                              target_path,
                               source_token_vocab,
                               target_token_vocab,
                               eos=None):
@@ -184,8 +229,8 @@ def ende_bpe_token_generator(tmp_dir, train):
   train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
   token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
   token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
-  return token_generator(train_path + ".en", train_path + ".de",
-                         token_vocab, EOS)
+  return token_generator(train_path + ".en", train_path + ".de", token_vocab,
+                         EOS)
 
 
 _ENDE_TRAIN_DATASETS = [
@@ -240,22 +285,15 @@ def ende_bpe_token_generator(tmp_dir, train):
     ],
 ]
 
-_ZHEN_TRAIN_DATASETS = [
-    [
-        ("http://data.statmt.org/wmt17/translation-task/"
-         "training-parallel-nc-v12.tgz"),
-        ("training/news-commentary-v12.zh-en.zh",
-         "training/news-commentary-v12.zh-en.en")
-    ]
-]
+_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
+                          "training-parallel-nc-v12.tgz"),
+                         ("training/news-commentary-v12.zh-en.zh",
+                          "training/news-commentary-v12.zh-en.en")]]
 
-_ZHEN_TEST_DATASETS = [
-    [
-        "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-        ("dev/newsdev2017-zhen-src.zh",
-         "dev/newsdev2017-zhen-ref.en")
-    ]
-]
+_ZHEN_TEST_DATASETS = [[
+    "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+    ("dev/newsdev2017-zhen-src.zh", "dev/newsdev2017-zhen-ref.en")
+]]
 
 
 def _compile_data(tmp_dir, datasets, filename):
@@ -317,23 +355,21 @@ def ende_character_generator(tmp_dir, train):
                              character_vocab, EOS)
 
 
-def zhen_wordpiece_token_generator(tmp_dir, train,
-                                   source_vocab_size,
+def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size,
                                    target_vocab_size):
   """Wordpiece generator for the WMT'17 zh-en dataset."""
   datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
   source_datasets = [[item[0], [item[1][0]]] for item in datasets]
   target_datasets = [[item[0], [item[1][1]]] for item in datasets]
   source_vocab = generator_utils.get_or_generate_vocab(
-      tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size,
-      source_vocab_size, source_datasets)
+      tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size, source_vocab_size,
+      source_datasets)
   target_vocab = generator_utils.get_or_generate_vocab(
-      tmp_dir, "tokens.vocab.en.%d" % target_vocab_size,
-      target_vocab_size, target_datasets)
+      tmp_dir, "tokens.vocab.en.%d" % target_vocab_size, target_vocab_size,
+      target_datasets)
   tag = "train" if train else "dev"
   data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag)
-  return bi_vocabs_token_generator(data_path + ".lang1",
-                                   data_path + ".lang2",
+  return bi_vocabs_token_generator(data_path + ".lang1", data_path + ".lang2",
                                    source_vocab, target_vocab, EOS)
 
 
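The zh-en hunk above is a pure reflow of argument wrapping; the generator's behavior is unchanged. For orientation, a minimal sketch of consuming it, assuming (as with token_generator earlier in this file) each yielded item is a dict with "inputs" and "targets" token-id lists; the vocab sizes and call pattern here are illustrative only.

def peek_zhen_example(tmp_dir):
  # Build the dev-set generator with 32k source (zh) and target (en)
  # wordpiece vocabularies, then pull a single example off it.
  gen = zhen_wordpiece_token_generator(tmp_dir, False, 2**15, 2**15)
  first = next(gen)
  return len(first["inputs"]), len(first["targets"])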
@@ -366,17 +402,15 @@ def parsing_character_generator(tmp_dir, train):
   return character_generator(text_filepath, tags_filepath, character_vocab, EOS)
 
 
-def tabbed_parsing_token_generator(tmp_dir, train, prefix,
-                                   source_vocab_size, target_vocab_size):
+def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size,
+                                   target_vocab_size):
   """Generate source and target data from a single file."""
   source_vocab = generator_utils.get_or_generate_tabbed_vocab(
       tmp_dir, "parsing_train.pairs", 0,
-      prefix + "_source.tokens.vocab.%d" % source_vocab_size,
-      source_vocab_size)
+      prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size)
   target_vocab = generator_utils.get_or_generate_tabbed_vocab(
       tmp_dir, "parsing_train.pairs", 1,
-      prefix + "_target.tokens.vocab.%d" % target_vocab_size,
-      target_vocab_size)
+      prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size)
   filename = "parsing_%s" % ("train" if train else "dev")
   pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
   return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
@@ -395,5 +429,5 @@ def parsing_token_generator(tmp_dir, train, vocab_size):
       tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
   filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
   tree_filepath = os.path.join(tmp_dir, filename)
-  return wsj_parsing.token_generator(tree_filepath,
-                                     symbolizer_vocab, symbolizer_vocab, EOS)
+  return wsj_parsing.token_generator(tree_filepath, symbolizer_vocab,
+                                     symbolizer_vocab, EOS)