|
19 | 19 | from __future__ import division |
20 | 20 | from __future__ import print_function |
21 | 21 |
|
| 22 | +import glob |
22 | 23 | import os |
| 24 | +import stat |
23 | 25 | import tarfile |
24 | 26 |
|
25 | 27 | # Dependency imports |
@@ -264,6 +266,10 @@ def bi_vocabs_token_generator(source_path, |
264 | 266 |
|
265 | 267 | # English-Czech datasets |
266 | 268 | _ENCS_TRAIN_DATASETS = [ |
| 269 | + [ |
| 270 | + "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar", |
|  | 271 | + ('tsv', 3, 2, 'data.plaintext-format/*train.gz') |
| 272 | + ], |
267 | 273 | [ |
268 | 274 | "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz", # pylint: disable=line-too-long |
269 | 275 | ("training/news-commentary-v12.cs-en.en", |
@@ -369,38 +375,64 @@ def _compile_data(tmp_dir, datasets, filename): |
369 | 375 | url = dataset[0] |
370 | 376 | compressed_filename = os.path.basename(url) |
371 | 377 | compressed_filepath = os.path.join(tmp_dir, compressed_filename) |
372 | | - |
373 | | - lang1_filename, lang2_filename = dataset[1] |
374 | | - lang1_filepath = os.path.join(tmp_dir, lang1_filename) |
375 | | - lang2_filepath = os.path.join(tmp_dir, lang2_filename) |
376 | | - is_sgm = (lang1_filename.endswith("sgm") and |
377 | | - lang2_filename.endswith("sgm")) |
378 | | - |
379 | 378 | generator_utils.maybe_download(tmp_dir, compressed_filename, url) |
380 | | - if not (os.path.exists(lang1_filepath) and |
381 | | - os.path.exists(lang2_filepath)): |
382 | | - # For .tar.gz and .tgz files, we read compressed. |
383 | | - mode = "r:gz" if compressed_filepath.endswith("gz") else "r" |
384 | | - with tarfile.open(compressed_filepath, mode) as corpus_tar: |
385 | | - corpus_tar.extractall(tmp_dir) |
386 | | - if lang1_filepath.endswith(".gz"): |
387 | | - new_filepath = lang1_filepath.strip(".gz") |
388 | | - generator_utils.gunzip_file(lang1_filepath, new_filepath) |
389 | | - lang1_filepath = new_filepath |
390 | | - if lang2_filepath.endswith(".gz"): |
391 | | - new_filepath = lang2_filepath.strip(".gz") |
392 | | - generator_utils.gunzip_file(lang2_filepath, new_filepath) |
393 | | - lang2_filepath = new_filepath |
394 | | - with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file: |
395 | | - with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file: |
396 | | - line1, line2 = lang1_file.readline(), lang2_file.readline() |
397 | | - while line1 or line2: |
398 | | - line1res = _preprocess_sgm(line1, is_sgm) |
399 | | - line2res = _preprocess_sgm(line2, is_sgm) |
400 | | - if line1res or line2res: |
401 | | - lang1_resfile.write(line1res.strip() + "\n") |
402 | | - lang2_resfile.write(line2res.strip() + "\n") |
| 379 | + |
| 380 | + if dataset[1][0] == 'tsv': |
| 381 | + _, src_column, trg_column, glob_pattern = dataset[1] |
| 382 | + filenames = glob.glob(os.path.join(tmp_dir, glob_pattern)) |
| 383 | + if not filenames: |
| 384 | + mode = "r:gz" if compressed_filepath.endswith("gz") else "r" # *.tgz *.tar.gz |
| 385 | + with tarfile.open(compressed_filepath, mode) as corpus_tar: |
| 386 | + corpus_tar.extractall(tmp_dir) |
| 387 | + filenames = glob.glob(os.path.join(tmp_dir, glob_pattern)) |
| 388 | + for tsv_filename in filenames: |
| 389 | + if tsv_filename.endswith(".gz"): |
| 390 | + new_filename = tsv_filename.strip(".gz") |
| 391 | + try: |
| 392 | + generator_utils.gunzip_file(tsv_filename, new_filename) |
| 393 | + except PermissionError: |
| 394 | + tsvdir = os.path.dirname(tsv_filename) |
| 395 | + os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE) |
| 396 | + generator_utils.gunzip_file(tsv_filename, new_filename) |
| 397 | + tsv_filename = new_filename |
| 398 | + with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file: |
| 399 | + for line in tsv_file: |
| 400 | + if line and "\t" in line: |
| 401 | + parts = line.split("\t") |
| 402 | + source, target = parts[src_column], parts[trg_column] |
| 403 | + lang1_resfile.write(source.strip() + "\n") |
| 404 | + lang2_resfile.write(target.strip() + "\n") |
| 405 | + else: |
| 406 | + lang1_filename, lang2_filename = dataset[1] |
| 407 | + lang1_filepath = os.path.join(tmp_dir, lang1_filename) |
| 408 | + lang2_filepath = os.path.join(tmp_dir, lang2_filename) |
| 409 | + is_sgm = (lang1_filename.endswith("sgm") and |
| 410 | + lang2_filename.endswith("sgm")) |
| 411 | + |
| 412 | + if not (os.path.exists(lang1_filepath) and |
| 413 | + os.path.exists(lang2_filepath)): |
| 414 | + # For .tar.gz and .tgz files, we read compressed. |
| 415 | + mode = "r:gz" if compressed_filepath.endswith("gz") else "r" |
| 416 | + with tarfile.open(compressed_filepath, mode) as corpus_tar: |
| 417 | + corpus_tar.extractall(tmp_dir) |
| 418 | + if lang1_filepath.endswith(".gz"): |
| 419 | + new_filepath = lang1_filepath.strip(".gz") |
| 420 | + generator_utils.gunzip_file(lang1_filepath, new_filepath) |
| 421 | + lang1_filepath = new_filepath |
| 422 | + if lang2_filepath.endswith(".gz"): |
| 423 | + new_filepath = lang2_filepath.strip(".gz") |
| 424 | + generator_utils.gunzip_file(lang2_filepath, new_filepath) |
| 425 | + lang2_filepath = new_filepath |
| 426 | + with tf.gfile.GFile(lang1_filepath, mode="r") as lang1_file: |
| 427 | + with tf.gfile.GFile(lang2_filepath, mode="r") as lang2_file: |
403 | 428 | line1, line2 = lang1_file.readline(), lang2_file.readline() |
| 429 | + while line1 or line2: |
| 430 | + line1res = _preprocess_sgm(line1, is_sgm) |
| 431 | + line2res = _preprocess_sgm(line2, is_sgm) |
| 432 | + if line1res or line2res: |
| 433 | + lang1_resfile.write(line1res.strip() + "\n") |
| 434 | + lang2_resfile.write(line2res.strip() + "\n") |
| 435 | + line1, line2 = lang1_file.readline(), lang2_file.readline() |
404 | 436 |
|
405 | 437 | return filename |
406 | 438 |
|
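Note on the new dataset entry format: entries whose second element starts with 'tsv' are interpreted by the branch above as ('tsv', source_column, target_column, glob_pattern), where the columns are 0-based indices into each tab-separated line and the glob is resolved relative to tmp_dir. A minimal standalone sketch of that per-line handling (the sample line and its column layout are illustrative, not taken from the actual CzEng files):

    # Sketch of the TSV branch in _compile_data above; the sample data is hypothetical.
    dataset_spec = ("tsv", 3, 2, "data.plaintext-format/*train.gz")
    _, src_column, trg_column, glob_pattern = dataset_spec

    sample_line = "sent-id\tfilter-score\tceska veta\tenglish sentence\n"
    if sample_line and "\t" in sample_line:
      parts = sample_line.split("\t")
      source, target = parts[src_column], parts[trg_column]
      print(source.strip())  # column 3 -> source side ("english sentence")
      print(target.strip())  # column 2 -> target side ("ceska veta")
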
@@ -630,13 +662,18 @@ def vocab_name(self): |
630 | 662 |
|
631 | 663 | def generator(self, data_dir, tmp_dir, train): |
632 | 664 | datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS |
633 | | - source_datasets = [[item[0], [item[1][0]]] for item in datasets] |
634 | | - target_datasets = [[item[0], [item[1][1]]] for item in datasets] |
635 | | - symbolizer_vocab = generator_utils.get_or_generate_vocab( |
636 | | - data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, |
637 | | - source_datasets + target_datasets) |
638 | 665 | tag = "train" if train else "dev" |
639 | 666 | data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag) |
| 667 | + vocab_datasets = [] |
|  | 668 | + # CzEng contains 100 gz files with tab-separated columns, so we expect it to |
|  | 669 | + # be the first dataset in datasets and use the newly created *.lang{1,2} files instead. |
| 670 | + if datasets[0][0].endswith("data-plaintext-format.tar"): |
| 671 | + vocab_datasets.append([datasets[0][0], |
| 672 | + ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]]) |
| 673 | + datasets = datasets[1:] |
| 674 | + vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets] |
| 675 | + symbolizer_vocab = generator_utils.get_or_generate_vocab( |
| 676 | + data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets) |
640 | 677 | return token_generator(data_path + ".lang1", data_path + ".lang2", |
641 | 678 | symbolizer_vocab, EOS) |
642 | 679 |
|
|
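For illustration, the vocabulary handling above means that when the CzEng tar is the first training dataset, get_or_generate_vocab reads the already-compiled token files rather than the raw tab-separated archives. A rough sketch of what vocab_datasets ends up holding for the train split (only mirrors the list-building code above; entries other than CzEng are elided):

    # Illustrative structure of vocab_datasets for tag == "train", assuming the
    # CzEng plain-text tar is the first entry of _ENCS_TRAIN_DATASETS.
    vocab_datasets = [
        # CzEng: vocab is built from the files written by _compile_data ...
        ["https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
         ["wmt_encs_tok_train.lang1", "wmt_encs_tok_train.lang2"]],
        # ... while every remaining dataset contributes both of its own files:
        # [url, [lang1_filename, lang2_filename]]
    ]
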