
Commit 1b1d7ed

Author: Ryan Sepassi
Port WMT en-de tokens 8k/32k to new Problem registry
PiperOrigin-RevId: 162025600
1 parent 4b4c800 commit 1b1d7ed

7 files changed (+109 lines, -84 lines)

tensor2tensor/bin/t2t-datagen

Lines changed: 0 additions & 8 deletions
@@ -134,14 +134,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "wmt_ende_bpe32k": (
         lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
         lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
-    "wmt_ende_tokens_8k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
-    ),
-    "wmt_ende_tokens_32k": (
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
-        lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
-    ),
     "wmt_zhen_tokens_32k": (
         lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
                                                    2**15, 2**15),

tensor2tensor/data_generators/algorithmic.py

Lines changed: 4 additions & 4 deletions
@@ -37,12 +37,12 @@ def num_symbols(self):
     return 2

   def generate_data(self, data_dir, _):
-    utils.generate_files(
+    utils.generate_dataset_and_shuffle(
         identity_generator(self.num_symbols, 40, 100000),
-        self.training_filepaths(data_dir, 100))
-    utils.generate_files(
+        self.training_filepaths(data_dir, 100, shuffled=True),
         identity_generator(self.num_symbols, 400, 10000),
-        self.dev_filepaths(data_dir, 1))
+        self.dev_filepaths(data_dir, 1, shuffled=True),
+        shuffle=False)

   def hparams(self, defaults, unused_model_hparams):
     p = defaults

tensor2tensor/data_generators/generator_utils.py

Lines changed: 11 additions & 0 deletions
@@ -359,6 +359,17 @@ def write_records(records, out_filename):
   writer.close()


+def generate_dataset_and_shuffle(train_gen,
+                                 train_paths,
+                                 dev_gen,
+                                 dev_paths,
+                                 shuffle=True):
+  generate_files(train_gen, train_paths)
+  generate_files(dev_gen, dev_paths)
+  if shuffle:
+    shuffle_dataset(train_paths + dev_paths)
+
+
 def shuffle_dataset(filenames):
   tf.logging.info("Shuffling data...")
   for fname in filenames:
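Note: the new helper writes the train and dev generators to their shard files with generate_files and then, unless shuffle=False is passed, shard-shuffles every supplied path with shuffle_dataset. Below is a minimal sketch of the shuffle=False variant used by the algorithmic problem above; MyRandomProblem, my_random_problem and my_random_generator are hypothetical illustrations, not part of this commit.

import random

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


def my_random_generator(length, num_examples):
  """Hypothetical stand-in generator whose output order is already random."""
  for _ in range(num_examples):
    seq = [random.randint(2, 10) for _ in range(length)]
    yield {"inputs": seq, "targets": seq}


@registry.register_problem("my_random_problem")  # hypothetical problem name
class MyRandomProblem(problem.Problem):

  def generate_data(self, data_dir, _):
    # Write shards directly under their final (shuffled=True) filenames and
    # skip the post-hoc shard shuffle, mirroring the algorithmic.py change.
    generator_utils.generate_dataset_and_shuffle(
        my_random_generator(40, 100000),
        self.training_filepaths(data_dir, 100, shuffled=True),
        my_random_generator(40, 10000),
        self.dev_filepaths(data_dir, 1, shuffled=True),
        shuffle=False)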

tensor2tensor/data_generators/problem.py

Lines changed: 19 additions & 9 deletions
@@ -78,12 +78,18 @@ class Problem(object):
   New problems are specified by the following methods:

   Data generation:
-    * generate_data(data_dir)
+    * generate_data(data_dir, tmp_dir)
         - Generate training and dev datasets into data_dir.
         - Additional files, e.g. vocabulary files, should also be written to
           data_dir.
+        - Downloads and other files can be written to tmp_dir.
+        - If you have a training and dev generator, you can generate the
+          training and dev datasets with
+          generator_utils.generate_dataset_and_shuffle.
         - Use the self.training_filepaths and self.dev_filepaths functions to
-          get sharded filenames.
+          get sharded filenames. If shuffled=False, the filenames will contain
+          an "unshuffled" suffix; you should then shuffle the data
+          shard-by-shard with generator_utils.shuffle_dataset.
   - Subclasses must override
     * dataset_filename()
         - Base filename for problem.

@@ -125,13 +131,17 @@ def feature_encoders(self, data_dir):
   # END SUBCLASS INTERFACE
   # ============================================================================

-  def training_filepaths(self, data_dir, num_shards):
-    return utils.train_data_filenames(self.dataset_filename(), data_dir,
-                                      num_shards)
-
-  def dev_filepaths(self, data_dir, num_shards):
-    return utils.dev_data_filenames(self.dataset_filename(), data_dir,
-                                    num_shards)
+  def training_filepaths(self, data_dir, num_shards, shuffled):
+    file_basename = self.dataset_filename()
+    if not shuffled:
+      file_basename += utils.UNSHUFFLED_SUFFIX
+    return utils.train_data_filenames(file_basename, data_dir, num_shards)
+
+  def dev_filepaths(self, data_dir, num_shards, shuffled):
+    file_basename = self.dataset_filename()
+    if not shuffled:
+      file_basename += utils.UNSHUFFLED_SUFFIX
+    return utils.dev_data_filenames(file_basename, data_dir, num_shards)

   def __init__(self, was_reversed=False, was_copy=False):
     """Create a Problem.

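As a condensed illustration of the interface this docstring now describes (tmp_dir for downloads, shuffled=False filepaths plus generate_dataset_and_shuffle), a hypothetical minimal Problem might look as follows; the problem name and my_example_generator are placeholders, not part of this commit.

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


@registry.register_problem("my_translation_toy")  # hypothetical problem name
class MyTranslationToy(problem.Problem):
  """Sketch of the generate_data(data_dir, tmp_dir) contract."""

  def generate_data(self, data_dir, tmp_dir):
    # Downloads and scratch files go to tmp_dir; final shards and vocab files
    # go to data_dir. shuffled=False produces "unshuffled"-suffixed filenames,
    # which generate_dataset_and_shuffle then shuffles shard-by-shard.
    generator_utils.generate_dataset_and_shuffle(
        my_example_generator(tmp_dir, train=True),   # hypothetical generator
        self.training_filepaths(data_dir, 100, shuffled=False),
        my_example_generator(tmp_dir, train=False),  # hypothetical generator
        self.dev_filepaths(data_dir, 1, shuffled=False))
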
tensor2tensor/data_generators/problem_hparams.py

Lines changed: 0 additions & 23 deletions
@@ -456,26 +456,6 @@ def wmt_ende_characters(unused_model_hparams):
   return p


-def wmt_ende_tokens(model_hparams, wrong_vocab_size):
-  """English to German translation benchmark."""
-  p = default_problem_hparams()
-  # This vocab file must be present within the data directory.
-  vocab_filename = os.path.join(model_hparams.data_dir,
-                                "tokens.vocab.%d" % wrong_vocab_size)
-  subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
-  p.input_modality = {
-      "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
-  }
-  p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
-  p.vocabulary = {
-      "inputs": subtokenizer,
-      "targets": subtokenizer,
-  }
-  p.input_space_id = 3
-  p.target_space_id = 8
-  return p
-
-
 def wmt_zhen_tokens(model_hparams, wrong_vocab_size):
   """Chinese to English translation benchmark."""
   p = default_problem_hparams()

@@ -751,9 +731,6 @@ def img2img_imagenet(unused_model_hparams):
     "wmt_enfr_tokens_32k_combined": lambda p: wmt_enfr_tokens(p, 2**15),
     "wmt_enfr_tokens_128k": lambda p: wmt_enfr_tokens(p, 2**17),
     "wmt_ende_characters": wmt_ende_characters,
-    "wmt_ende_tokens_8k": lambda p: wmt_ende_tokens(p, 2**13),
-    "wmt_ende_tokens_32k": lambda p: wmt_ende_tokens(p, 2**15),
-    "wmt_ende_tokens_128k": lambda p: wmt_ende_tokens(p, 2**17),
     "wmt_ende_bpe32k": wmt_ende_bpe32k,
     "wmt_zhen_tokens_32k": lambda p: wmt_zhen_tokens(p, 2**15),
     "image_cifar10_tune": image_cifar10,

tensor2tensor/data_generators/wmt.py

Lines changed: 70 additions & 36 deletions
@@ -24,20 +24,64 @@
 # Dependency imports

 from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.utils import registry

 import tensorflow as tf

-
 tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir."
                        "Download from https://drive.google.com/open?"
                        "id=0B_bZck-ksdkpM25jRUN2X2UxMm8")

-
 FLAGS = tf.flags.FLAGS


+@registry.register_problem("wmt_ende_tokens_8k")
+class WMTEnDeTokens8k(problem.Problem):
+  """Problem spec for WMT En-De translation."""
+
+  @property
+  def target_vocab_size(self):
+    return 2**13  # 8192
+
+  def feature_encoders(self, data_dir):
+    return _default_wmt_feature_encoders(data_dir, self.target_vocab_size)
+
+  def generate_data(self, data_dir, tmp_dir):
+    generator_utils.generate_dataset_and_shuffle(
+        ende_wordpiece_token_generator(tmp_dir, True, self.target_vocab_size),
+        self.training_filepaths(data_dir, 100, shuffled=False),
+        ende_wordpiece_token_generator(tmp_dir, False, self.target_vocab_size),
+        self.dev_filepaths(data_dir, 1, shuffled=False))
+
+  def hparams(self, defaults, unused_model_hparams):
+    p = defaults
+    vocab_size = self._encoders["inputs"].vocab_size
+    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
+    p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
+    p.input_space_id = problem.SpaceID.EN_TOK
+    p.target_space_id = problem.SpaceID.DE_TOK
+
+
+@registry.register_problem("wmt_ende_tokens_32k")
+class WMTEnDeTokens32k(WMTEnDeTokens8k):
+
+  @property
+  def target_vocab_size(self):
+    return 2**15  # 32768
+
+
+def _default_wmt_feature_encoders(data_dir, target_vocab_size):
+  vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size)
+  subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
+  return {
+      "inputs": subtokenizer,
+      "targets": subtokenizer,
+  }
+
+
 # End-of-sentence marker.
 EOS = text_encoder.EOS_TOKEN

@@ -130,7 +174,8 @@ def token_generator(source_path, target_path, token_vocab, eos=None):
     source, target = source_file.readline(), target_file.readline()


-def bi_vocabs_token_generator(source_path, target_path,
+def bi_vocabs_token_generator(source_path,
+                              target_path,
                               source_token_vocab,
                               target_token_vocab,
                               eos=None):

@@ -184,8 +229,8 @@ def ende_bpe_token_generator(tmp_dir, train):
   train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
   token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
   token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
-  return token_generator(train_path + ".en", train_path + ".de",
-                         token_vocab, EOS)
+  return token_generator(train_path + ".en", train_path + ".de", token_vocab,
+                         EOS)


 _ENDE_TRAIN_DATASETS = [

@@ -240,22 +285,15 @@ def ende_bpe_token_generator(tmp_dir, train):
   ],
 ]

-_ZHEN_TRAIN_DATASETS = [
-    [
-        ("http://data.statmt.org/wmt17/translation-task/"
-         "training-parallel-nc-v12.tgz"),
-        ("training/news-commentary-v12.zh-en.zh",
-         "training/news-commentary-v12.zh-en.en")
-    ]
-]
+_ZHEN_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
+                          "training-parallel-nc-v12.tgz"),
+                         ("training/news-commentary-v12.zh-en.zh",
+                          "training/news-commentary-v12.zh-en.en")]]

-_ZHEN_TEST_DATASETS = [
-    [
-        "http://data.statmt.org/wmt17/translation-task/dev.tgz",
-        ("dev/newsdev2017-zhen-src.zh",
-         "dev/newsdev2017-zhen-ref.en")
-    ]
-]
+_ZHEN_TEST_DATASETS = [[
+    "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+    ("dev/newsdev2017-zhen-src.zh", "dev/newsdev2017-zhen-ref.en")
+]]


 def _compile_data(tmp_dir, datasets, filename):

@@ -317,23 +355,21 @@ def ende_character_generator(tmp_dir, train):
                              character_vocab, EOS)


-def zhen_wordpiece_token_generator(tmp_dir, train,
-                                   source_vocab_size,
+def zhen_wordpiece_token_generator(tmp_dir, train, source_vocab_size,
                                    target_vocab_size):
   """Wordpiece generator for the WMT'17 zh-en dataset."""
   datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS
   source_datasets = [[item[0], [item[1][0]]] for item in datasets]
   target_datasets = [[item[0], [item[1][1]]] for item in datasets]
   source_vocab = generator_utils.get_or_generate_vocab(
-      tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size,
-      source_vocab_size, source_datasets)
+      tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size, source_vocab_size,
+      source_datasets)
   target_vocab = generator_utils.get_or_generate_vocab(
-      tmp_dir, "tokens.vocab.en.%d" % target_vocab_size,
-      target_vocab_size, target_datasets)
+      tmp_dir, "tokens.vocab.en.%d" % target_vocab_size, target_vocab_size,
+      target_datasets)
   tag = "train" if train else "dev"
   data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag)
-  return bi_vocabs_token_generator(data_path + ".lang1",
-                                   data_path + ".lang2",
+  return bi_vocabs_token_generator(data_path + ".lang1", data_path + ".lang2",
                                    source_vocab, target_vocab, EOS)

@@ -366,17 +402,15 @@ def parsing_character_generator(tmp_dir, train):
   return character_generator(text_filepath, tags_filepath, character_vocab, EOS)


-def tabbed_parsing_token_generator(tmp_dir, train, prefix,
-                                   source_vocab_size, target_vocab_size):
+def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size,
+                                   target_vocab_size):
   """Generate source and target data from a single file."""
   source_vocab = generator_utils.get_or_generate_tabbed_vocab(
       tmp_dir, "parsing_train.pairs", 0,
-      prefix + "_source.tokens.vocab.%d" % source_vocab_size,
-      source_vocab_size)
+      prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size)
   target_vocab = generator_utils.get_or_generate_tabbed_vocab(
       tmp_dir, "parsing_train.pairs", 1,
-      prefix + "_target.tokens.vocab.%d" % target_vocab_size,
-      target_vocab_size)
+      prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size)
   filename = "parsing_%s" % ("train" if train else "dev")
   pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
   return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)

@@ -395,5 +429,5 @@ def parsing_token_generator(tmp_dir, train, vocab_size):
       tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size)
   filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev")
   tree_filepath = os.path.join(tmp_dir, filename)
-  return wsj_parsing.token_generator(tree_filepath,
-                                     symbolizer_vocab, symbolizer_vocab, EOS)
+  return wsj_parsing.token_generator(tree_filepath, symbolizer_vocab,
+                                     symbolizer_vocab, EOS)

tensor2tensor/utils/trainer_utils_test.py

Lines changed: 5 additions & 4 deletions
@@ -34,13 +34,13 @@
 @registry.register_problem
 class TinyAlgo(algorithmic.AlgorithmicIdentityBinary40):

-  def generate_data(self, data_dir):
+  def generate_data(self, data_dir, _):
     generator_utils.generate_files(
         algorithmic.identity_generator(self.num_symbols, 40, 100000),
-        self.training_filepaths(data_dir, 1), 100)
+        self.training_filepaths(data_dir, 1, shuffled=True), 100)
     generator_utils.generate_files(
         algorithmic.identity_generator(self.num_symbols, 400, 10000),
-        self.dev_filepaths(data_dir, 1), 100)
+        self.dev_filepaths(data_dir, 1, shuffled=True), 100)


 @registry.register_hparams

@@ -61,7 +61,8 @@ def setUpClass(cls):
     # Generate a small test dataset
     FLAGS.problems = "tiny_algo"
     TrainerUtilsTest.data_dir = tf.test.get_temp_dir()
-    registry.problem(FLAGS.problems).generate_data(TrainerUtilsTest.data_dir)
+    registry.problem(FLAGS.problems).generate_data(TrainerUtilsTest.data_dir,
+                                                   None)

   def testModelsImported(self):
     models = registry.list_models()
