This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f8d5ee8

bigger translate_encs_wmt32k training data, tsv support
* CzEng 1.0 (15M sentence pairs) is part of the WMT training data, but has to be downloaded separately.
* This commit adds (a bit hacky, I admit) support for:
  - src and trg sentences stored in arbitrary columns of tsv files
  - wildcard patterns to support many files in a tar (e.g. 100 in the case of CzEng)
1 parent 6e82543 commit f8d5ee8
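
For illustration of the dataset-entry format described in the commit message (this sketch is not part of the commit): a new-style entry carries ('tsv', src_column, trg_column, glob_pattern) instead of a pair of file names, and the wildcard pattern may match many extracted files. The helper iter_tsv_pairs and the sample row below are hypothetical; only the entry format, URL and column indices come from the diff.

import glob
import os

# New-style dataset entry: [url, ('tsv', src_column, trg_column, glob_pattern)].
# The commit reads the source sentence from column 3 and the target sentence
# from column 2 of each tab-separated line (columns are 0-indexed).
CZENG_ENTRY = [
    "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
    ('tsv', 3, 2, 'data.plaintext-format/*train.gz'),
]


def iter_tsv_pairs(tmp_dir, dataset):
  """Hypothetical helper mirroring the new TSV branch of _compile_data."""
  _, src_column, trg_column, glob_pattern = dataset[1]
  # The wildcard pattern expands to the (possibly many) extracted files.
  for tsv_filename in glob.glob(os.path.join(tmp_dir, glob_pattern)):
    with open(tsv_filename) as tsv_file:
      for line in tsv_file:
        if line and "\t" in line:
          parts = line.split("\t")
          yield parts[src_column].strip(), parts[trg_column].strip()


# A made-up row with two leading metadata columns, then Czech, then English:
sample = "pair-id\tmeta\tAhoj světe .\tHello world .\n"
parts = sample.split("\t")
print(parts[3].strip(), "|||", parts[2].strip())  # Hello world . ||| Ahoj světe .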

File tree

1 file changed (+72 −35 lines):

  • tensor2tensor/data_generators/wmt.py


tensor2tensor/data_generators/wmt.py

Lines changed: 72 additions & 35 deletions
@@ -19,7 +19,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import glob
 import os
+import stat
 import tarfile
 
 # Dependency imports
@@ -266,6 +268,10 @@ def bi_vocabs_token_generator(source_path,
 
 # English-Czech datasets
 _ENCS_TRAIN_DATASETS = [
+    [
+        "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
+        ('tsv', 3, 2, 'data.plaintext-format/*train.gz')
+    ],
     [
         "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",  # pylint: disable=line-too-long
         ("training/news-commentary-v12.cs-en.en",
@@ -345,38 +351,64 @@ def _compile_data(tmp_dir, datasets, filename):
         url = dataset[0]
         compressed_filename = os.path.basename(url)
         compressed_filepath = os.path.join(tmp_dir, compressed_filename)
-
-        lang1_filename, lang2_filename = dataset[1]
-        lang1_filepath = os.path.join(tmp_dir, lang1_filename)
-        lang2_filepath = os.path.join(tmp_dir, lang2_filename)
-        is_sgm = (lang1_filename.endswith("sgm") and
-                  lang2_filename.endswith("sgm"))
-
         generator_utils.maybe_download(tmp_dir, compressed_filename, url)
-        if not (os.path.exists(lang1_filepath) and
-                os.path.exists(lang2_filepath)):
-          # For .tar.gz and .tgz files, we read compressed.
-          mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
-          with tarfile.open(compressed_filepath, mode) as corpus_tar:
-            corpus_tar.extractall(tmp_dir)
-        if lang1_filepath.endswith(".gz"):
-          new_filepath = lang1_filepath.strip(".gz")
-          generator_utils.gunzip_file(lang1_filepath, new_filepath)
-          lang1_filepath = new_filepath
-        if lang2_filepath.endswith(".gz"):
-          new_filepath = lang2_filepath.strip(".gz")
-          generator_utils.gunzip_file(lang2_filepath, new_filepath)
-          lang2_filepath = new_filepath
-        with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
-          with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
-            line1, line2 = lang1_file.readline(), lang2_file.readline()
-            while line1 or line2:
-              line1res = _preprocess_sgm(line1, is_sgm)
-              line2res = _preprocess_sgm(line2, is_sgm)
-              if line1res or line2res:
-                lang1_resfile.write(line1res.strip() + "\n")
-                lang2_resfile.write(line2res.strip() + "\n")
+
+        if dataset[1][0] == 'tsv':
+          _, src_column, trg_column, glob_pattern = dataset[1]
+          filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+          if not filenames:
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"  # *.tgz *.tar.gz
+            with tarfile.open(compressed_filepath, mode) as corpus_tar:
+              corpus_tar.extractall(tmp_dir)
+            filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+          for tsv_filename in filenames:
+            if tsv_filename.endswith(".gz"):
+              new_filename = tsv_filename.strip(".gz")
+              try:
+                generator_utils.gunzip_file(tsv_filename, new_filename)
+              except PermissionError:
+                tsvdir = os.path.dirname(tsv_filename)
+                os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE)
+                generator_utils.gunzip_file(tsv_filename, new_filename)
+              tsv_filename = new_filename
+            with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
+              for line in tsv_file:
+                if line and "\t" in line:
+                  parts = line.split("\t")
+                  source, target = parts[src_column], parts[trg_column]
+                  lang1_resfile.write(source.strip() + "\n")
+                  lang2_resfile.write(target.strip() + "\n")
+        else:
+          lang1_filename, lang2_filename = dataset[1]
+          lang1_filepath = os.path.join(tmp_dir, lang1_filename)
+          lang2_filepath = os.path.join(tmp_dir, lang2_filename)
+          is_sgm = (lang1_filename.endswith("sgm") and
+                    lang2_filename.endswith("sgm"))
+
+          if not (os.path.exists(lang1_filepath) and
+                  os.path.exists(lang2_filepath)):
+            # For .tar.gz and .tgz files, we read compressed.
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
+            with tarfile.open(compressed_filepath, mode) as corpus_tar:
+              corpus_tar.extractall(tmp_dir)
+          if lang1_filepath.endswith(".gz"):
+            new_filepath = lang1_filepath.strip(".gz")
+            generator_utils.gunzip_file(lang1_filepath, new_filepath)
+            lang1_filepath = new_filepath
+          if lang2_filepath.endswith(".gz"):
+            new_filepath = lang2_filepath.strip(".gz")
+            generator_utils.gunzip_file(lang2_filepath, new_filepath)
+            lang2_filepath = new_filepath
+          with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
+            with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
               line1, line2 = lang1_file.readline(), lang2_file.readline()
+              while line1 or line2:
+                line1res = _preprocess_sgm(line1, is_sgm)
+                line2res = _preprocess_sgm(line2, is_sgm)
+                if line1res or line2res:
+                  lang1_resfile.write(line1res.strip() + "\n")
+                  lang2_resfile.write(line2res.strip() + "\n")
+                line1, line2 = lang1_file.readline(), lang2_file.readline()
 
   return filename
 
@@ -603,13 +635,18 @@ def vocab_name(self):
 
   def generator(self, data_dir, tmp_dir, train):
     datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS
-    source_datasets = [[item[0], [item[1][0]]] for item in datasets]
-    target_datasets = [[item[0], [item[1][1]]] for item in datasets]
-    symbolizer_vocab = generator_utils.get_or_generate_vocab(
-        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
-        source_datasets + target_datasets)
     tag = "train" if train else "dev"
     data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
+    vocab_datasets = []
+    # CzEng contains 100 gz files with tab-separated columns, so let's expect
+    # it is the first dataset in datasets and use the newly created *.lang{1,2} files instead.
+    if datasets[0][0].endswith("data-plaintext-format.tar"):
+      vocab_datasets.append([datasets[0][0],
+                             ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]])
+      datasets = datasets[1:]
+    vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets]
+    symbolizer_vocab = generator_utils.get_or_generate_vocab(
+        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets)
     return token_generator(data_path + ".lang1", data_path + ".lang2",
                            symbolizer_vocab, EOS)
 

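A rough sketch (an assumption for illustration, not from the commit) of the shape vocab_datasets takes in generator() above when the CzEng tar is the first training dataset; tag and the entries shown follow the diff, the rest is summarized in comments.

# Sketch of the expected vocab_datasets for the training case (tag == "train"),
# assuming the CzEng entry is first in _ENCS_TRAIN_DATASETS.
tag = "train"
vocab_datasets = [
    # CzEng: the vocabulary is built from the already-compiled token files
    # rather than from the 100 gzipped TSV shards inside the tar.
    ["https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
     ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]],
    # ...followed by the remaining datasets in their unchanged
    # [url, [lang1_filename, lang2_filename]] form.
]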