
Commit c34d16c

Merge pull request #248 from martinpopel/biger-encs

Bigger translate_encs_wmt32k training data, tsv support

2 parents 5a00243 + f8d5ee8

2 files changed: +77, -41 lines

tensor2tensor/data_generators/generator_utils.py

Lines changed: 5 additions & 6 deletions
@@ -350,19 +350,18 @@ def generate():
     for source in sources:
       url = source[0]
       filename = os.path.basename(url)
-      read_type = "r:gz" if "tgz" in filename else "r"
-
       compressed_file = maybe_download(tmp_dir, filename, url)
 
-      with tarfile.open(compressed_file, read_type) as corpus_tar:
-        corpus_tar.extractall(tmp_dir)
-
       for lang_file in source[1]:
         tf.logging.info("Reading file: %s" % lang_file)
         filepath = os.path.join(tmp_dir, lang_file)
+        if not tf.gfile.Exists(filepath):
+          read_type = "r:gz" if filename.endswith("tgz") else "r"
+          with tarfile.open(compressed_file, read_type) as corpus_tar:
+            corpus_tar.extractall(tmp_dir)
 
         # For some datasets a second extraction is necessary.
-        if ".gz" in lang_file:
+        if lang_file.endswith(".gz"):
           new_filepath = os.path.join(tmp_dir, lang_file[:-3])
           if tf.gfile.Exists(new_filepath):
             tf.logging.info(
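Net effect of this hunk: the tarball is now extracted lazily, only when a file named in the source spec is missing from tmp_dir, so re-runs with a warm cache skip the costly extractall; the loose substring test "tgz" in filename also becomes a stricter suffix check. A minimal standalone sketch of the same lazy-extraction pattern, assuming a hypothetical helper name (extract_if_missing is not part of the library):

import os
import tarfile

def extract_if_missing(archive_path, member_names, dest_dir):
  """Extract archive_path into dest_dir only if some member is missing."""
  if all(os.path.exists(os.path.join(dest_dir, m)) for m in member_names):
    return  # warm cache: nothing to do
  read_type = "r:gz" if archive_path.endswith("tgz") else "r"
  with tarfile.open(archive_path, read_type) as corpus_tar:
    corpus_tar.extractall(dest_dir)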

tensor2tensor/data_generators/wmt.py

Lines changed: 72 additions & 35 deletions
@@ -19,7 +19,9 @@
 from __future__ import division
 from __future__ import print_function
 
+import glob
 import os
+import stat
 import tarfile
 
 # Dependency imports
@@ -264,6 +266,10 @@ def bi_vocabs_token_generator(source_path,
 
 # English-Czech datasets
 _ENCS_TRAIN_DATASETS = [
+    [
+        "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
+        ('tsv', 3, 2, 'data.plaintext-format/*train.gz')
+    ],
     [
         "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",  # pylint: disable=line-too-long
         ("training/news-commentary-v12.cs-en.en",
@@ -369,38 +375,64 @@ def _compile_data(tmp_dir, datasets, filename):
         url = dataset[0]
         compressed_filename = os.path.basename(url)
         compressed_filepath = os.path.join(tmp_dir, compressed_filename)
-
-        lang1_filename, lang2_filename = dataset[1]
-        lang1_filepath = os.path.join(tmp_dir, lang1_filename)
-        lang2_filepath = os.path.join(tmp_dir, lang2_filename)
-        is_sgm = (lang1_filename.endswith("sgm") and
-                  lang2_filename.endswith("sgm"))
-
         generator_utils.maybe_download(tmp_dir, compressed_filename, url)
-        if not (os.path.exists(lang1_filepath) and
-                os.path.exists(lang2_filepath)):
-          # For .tar.gz and .tgz files, we read compressed.
-          mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
-          with tarfile.open(compressed_filepath, mode) as corpus_tar:
-            corpus_tar.extractall(tmp_dir)
-        if lang1_filepath.endswith(".gz"):
-          new_filepath = lang1_filepath.strip(".gz")
-          generator_utils.gunzip_file(lang1_filepath, new_filepath)
-          lang1_filepath = new_filepath
-        if lang2_filepath.endswith(".gz"):
-          new_filepath = lang2_filepath.strip(".gz")
-          generator_utils.gunzip_file(lang2_filepath, new_filepath)
-          lang2_filepath = new_filepath
-        with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
-          with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
-            line1, line2 = lang1_file.readline(), lang2_file.readline()
-            while line1 or line2:
-              line1res = _preprocess_sgm(line1, is_sgm)
-              line2res = _preprocess_sgm(line2, is_sgm)
-              if line1res or line2res:
-                lang1_resfile.write(line1res.strip() + "\n")
-                lang2_resfile.write(line2res.strip() + "\n")
+
+        if dataset[1][0] == 'tsv':
+          _, src_column, trg_column, glob_pattern = dataset[1]
+          filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+          if not filenames:
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"  # *.tgz *.tar.gz
+            with tarfile.open(compressed_filepath, mode) as corpus_tar:
+              corpus_tar.extractall(tmp_dir)
+            filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
+          for tsv_filename in filenames:
+            if tsv_filename.endswith(".gz"):
+              new_filename = tsv_filename.strip(".gz")
+              try:
+                generator_utils.gunzip_file(tsv_filename, new_filename)
+              except PermissionError:
+                tsvdir = os.path.dirname(tsv_filename)
+                os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE)
+                generator_utils.gunzip_file(tsv_filename, new_filename)
+              tsv_filename = new_filename
+            with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
+              for line in tsv_file:
+                if line and "\t" in line:
+                  parts = line.split("\t")
+                  source, target = parts[src_column], parts[trg_column]
+                  lang1_resfile.write(source.strip() + "\n")
+                  lang2_resfile.write(target.strip() + "\n")
+        else:
+          lang1_filename, lang2_filename = dataset[1]
+          lang1_filepath = os.path.join(tmp_dir, lang1_filename)
+          lang2_filepath = os.path.join(tmp_dir, lang2_filename)
+          is_sgm = (lang1_filename.endswith("sgm") and
+                    lang2_filename.endswith("sgm"))
+
+          if not (os.path.exists(lang1_filepath) and
+                  os.path.exists(lang2_filepath)):
+            # For .tar.gz and .tgz files, we read compressed.
+            mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
+            with tarfile.open(compressed_filepath, mode) as corpus_tar:
+              corpus_tar.extractall(tmp_dir)
+          if lang1_filepath.endswith(".gz"):
+            new_filepath = lang1_filepath.strip(".gz")
+            generator_utils.gunzip_file(lang1_filepath, new_filepath)
+            lang1_filepath = new_filepath
+          if lang2_filepath.endswith(".gz"):
+            new_filepath = lang2_filepath.strip(".gz")
+            generator_utils.gunzip_file(lang2_filepath, new_filepath)
+            lang2_filepath = new_filepath
+          with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file:
+            with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file:
               line1, line2 = lang1_file.readline(), lang2_file.readline()
+              while line1 or line2:
+                line1res = _preprocess_sgm(line1, is_sgm)
+                line2res = _preprocess_sgm(line2, is_sgm)
+                if line1res or line2res:
+                  lang1_resfile.write(line1res.strip() + "\n")
+                  lang2_resfile.write(line2res.strip() + "\n")
+                line1, line2 = lang1_file.readline(), lang2_file.readline()
 
   return filename
 
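Two details in the tsv branch are worth flagging. PermissionError is a Python 3 builtin (under Python 2 this except clause would not resolve); it is caught here because the CzEng tar can extract its directory read-only, so gunzip_file is retried after adding the owner-write bit. Also, str.strip(".gz") trims any of the characters '.', 'g', 'z' from both ends rather than removing a suffix; it happens to be harmless for these *train.gz names and mirrors the pre-existing lang1/lang2 branch. A self-contained sketch of the chmod-and-retry pattern, with gunzip as a hypothetical stand-in for generator_utils.gunzip_file:

import gzip
import os
import shutil
import stat

def gunzip(gz_path, out_path):
  # Hypothetical stand-in for generator_utils.gunzip_file.
  with gzip.open(gz_path, "rb") as f_in, open(out_path, "wb") as f_out:
    shutil.copyfileobj(f_in, f_out)

def gunzip_with_retry(gz_path, out_path):
  try:
    gunzip(gz_path, out_path)
  except PermissionError:
    # The directory may have been extracted read-only:
    # grant the owner write permission and try once more.
    dirname = os.path.dirname(gz_path)
    os.chmod(dirname, os.stat(dirname).st_mode | stat.S_IWRITE)
    gunzip(gz_path, out_path)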

@@ -630,13 +662,18 @@ def vocab_name(self):
 
   def generator(self, data_dir, tmp_dir, train):
     datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS
-    source_datasets = [[item[0], [item[1][0]]] for item in datasets]
-    target_datasets = [[item[0], [item[1][1]]] for item in datasets]
-    symbolizer_vocab = generator_utils.get_or_generate_vocab(
-        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
-        source_datasets + target_datasets)
     tag = "train" if train else "dev"
     data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
+    vocab_datasets = []
+    # CzEng contains 100 gz files with tab-separated columns, so let's expect
+    # it is the first dataset in datasets and use the newly created *.lang{1,2} files instead.
+    if datasets[0][0].endswith("data-plaintext-format.tar"):
+      vocab_datasets.append([datasets[0][0],
+          ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]])
+      datasets = datasets[1:]
+    vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets]
+    symbolizer_vocab = generator_utils.get_or_generate_vocab(
+        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets)
     return token_generator(data_path + ".lang1", data_path + ".lang2",
                            symbolizer_vocab, EOS)
 
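Vocabulary generation now runs after _compile_data instead of before it: when CzEng (the data-plaintext-format.tar entry) heads the dataset list, the vocab is built from the freshly compiled wmt_encs_tok_<tag>.lang1/.lang2 corpus files rather than from the 100 raw gz shards, while the remaining datasets keep their (url, [en_file, cs_file]) shape. For tag = "train" the resulting list plausibly looks like this (shape inferred from the diff, first two entries only):

vocab_datasets = [
    # CzEng: read the compiled corpus, not the raw tsv shards.
    ["https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
     ["wmt_encs_tok_train.lang1", "wmt_encs_tok_train.lang2"]],
    # Other datasets keep their original English/Czech file pairs.
    ["http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",
     ["training/news-commentary-v12.cs-en.en",
      "training/news-commentary-v12.cs-en.cs"]],
]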
