This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 43bfb9f

Merge pull request #154 from vthorsteinsson/iceparse
Source/target pair text files; Icelandic parsing support; fixes
2 parents 4617c01 + 5a72e5c commit 43bfb9f

9 files changed, 223 insertions(+), 83 deletions(-)


tensor2tensor/bin/t2t-datagen

File mode changed from 100644 to 100755.
Lines changed: 8 additions & 0 deletions
@@ -102,6 +102,14 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
+    "ice_parsing_tokens": (
+        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
+                                                   True, "ice", 2**13, 2**8),
+        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
+                                                   False, "ice", 2**13, 2**8)),
+    "ice_parsing_characters": (
+        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
+        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
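Each entry in _SUPPORTED_PROBLEM_GENERATORS pairs a training-data generator with an evaluation-data generator. As a hedged sketch (not part of the commit), this is roughly how the new Icelandic token problem could be inspected, assuming FLAGS.tmp_dir already contains the parsing pair files described in wmt.py below:

    import itertools

    train_gen_fn, dev_gen_fn = _SUPPORTED_PROBLEM_GENERATORS["ice_parsing_tokens"]
    for example in itertools.islice(train_gen_fn(), 3):
      # Each yielded example is a dict of integer lists, e.g.
      # {"inputs": [57, 824, ..., 1], "targets": [13, 4, ..., 1]}
      print(len(example["inputs"]), len(example["targets"]))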

tensor2tensor/bin/t2t-trainer

File mode changed from 100644 to 100755.

tensor2tensor/data_generators/generator_utils.py

File mode changed from 100644 to 100755.
Lines changed: 28 additions & 0 deletions
@@ -290,6 +290,34 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
   return vocab
 
 
+def get_or_generate_tabbed_vocab(tmp_dir, source_filename, index, vocab_filename, vocab_size):
+  """Generate a vocabulary from the source file. This is assumed to be
+  a file of source, target pairs, where each line contains a source string
+  and a target string, separated by a tab ('\t') character. The index
+  parameter specifies 0 for the source or 1 for the target."""
+  vocab_filepath = os.path.join(tmp_dir, vocab_filename)
+  if os.path.exists(vocab_filepath):
+    vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
+    return vocab
+
+  # Use Tokenizer to count the word occurrences.
+  token_counts = defaultdict(int)
+  filepath = os.path.join(tmp_dir, source_filename)
+  with tf.gfile.GFile(filepath, mode="r") as source_file:
+    for line in source_file:
+      line = line.strip()
+      if line and '\t' in line:
+        parts = line.split('\t', maxsplit = 1)
+        part = parts[index].strip()
+        for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
+          token_counts[tok] += 1
+
+  vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
+      vocab_size, token_counts, 1, 1e3)
+  vocab.store_to_file(vocab_filepath)
+  return vocab
+
+
 def read_records(filename):
   reader = tf.python_io.tf_record_iterator(filename)
   records = []
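For illustration, a hedged sketch of calling the new helper on a tab-separated pairs file; the tmp_dir variable and the sample line are assumptions, while the file and vocab names match the Icelandic generators added in wmt.py below:

    # parsing_train.pairs holds one example per line, source and target separated by '\t':
    #   Hundurinn eltir köttinn\t(S (NP Hundurinn) (VP eltir (NP köttinn)))   <- made-up sample
    source_vocab = get_or_generate_tabbed_vocab(
        tmp_dir, "parsing_train.pairs", 0,   # index 0: build the vocab from the source column
        "ice_source.tokens.vocab.8192", 2**13)
    target_vocab = get_or_generate_tabbed_vocab(
        tmp_dir, "parsing_train.pairs", 1,   # index 1: build the vocab from the target column
        "ice_target.tokens.vocab.256", 2**8)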

tensor2tensor/data_generators/problem_hparams.py

File mode changed from 100644 to 100755.
Lines changed: 48 additions & 11 deletions
@@ -66,14 +66,13 @@ def parse_problem_name(problem_name):
     was_copy: A boolean.
   """
   # Recursively strip tags until we reach a base name.
-  if len(problem_name) > 4 and problem_name[-4:] == "_rev":
+  if problem_name.endswith("_rev"):
     base, _, was_copy = parse_problem_name(problem_name[:-4])
     return base, True, was_copy
-  elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
+  if problem_name.endswith("_copy"):
     base, was_reversed, _ = parse_problem_name(problem_name[:-5])
     return base, was_reversed, True
-  else:
-    return problem_name, False, False
+  return problem_name, False, False
 
 
 def _lookup_problem_hparams_fn(name):
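The rewritten suffix handling keeps the same contract: a (base_name, was_reversed, was_copy) triple, with tags stripped recursively. A minimal sketch of the expected behaviour (the problem names are illustrative):

    parse_problem_name("ice_parsing_tokens")           # -> ("ice_parsing_tokens", False, False)
    parse_problem_name("ice_parsing_tokens_rev")       # -> ("ice_parsing_tokens", True, False)
    parse_problem_name("ice_parsing_tokens_rev_copy")  # -> ("ice_parsing_tokens", True, True)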
@@ -178,6 +177,9 @@ def default_problem_hparams():
       # 14: Parse characters
       # 15: Parse tokens
       # 16: Chinese tokens
+      # 17: Icelandic characters
+      # 18: Icelandic tokens
+      # 19: Icelandic parse tokens
       # Add more above if needed.
       input_space_id=0,
       target_space_id=0,
@@ -198,7 +200,8 @@ def default_problem_hparams():
       # the targets. For instance `problem_copy` will copy the inputs, but
       # `problem_rev_copy` will copy the targets.
       was_reversed=False,
-      was_copy=False,)
+      was_copy=False,
+  )
 
 
 def test_problem_hparams(unused_model_hparams, input_vocab_size,
@@ -532,7 +535,7 @@ def wmt_concat(model_hparams, wrong_vocab_size):
   return p
 
 
-def wmt_parsing_characters(unused_model_hparams):
+def wmt_parsing_characters(model_hparams):
   """English to parse tree translation benchmark."""
   p = default_problem_hparams()
   p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)}
@@ -576,7 +579,8 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size):
   return p
 
 
-def wsj_parsing_tokens(model_hparams, wrong_source_vocab_size,
+def wsj_parsing_tokens(model_hparams, prefix,
+                       wrong_source_vocab_size,
                        wrong_target_vocab_size):
   """English to parse tree translation benchmark.
 
@@ -595,10 +599,10 @@ def wsj_parsing_tokens(model_hparams, wrong_source_vocab_size,
   # This vocab file must be present within the data directory.
   source_vocab_filename = os.path.join(
       model_hparams.data_dir,
-      "wsj_source.tokens.vocab.%d" % wrong_source_vocab_size)
+      prefix + "_source.tokens.vocab.%d" % wrong_source_vocab_size)
   target_vocab_filename = os.path.join(
       model_hparams.data_dir,
-      "wsj_target.tokens.vocab.%d" % wrong_target_vocab_size)
+      prefix + "_target.tokens.vocab.%d" % wrong_target_vocab_size)
   source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
   target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
   p.input_modality = {
@@ -615,6 +619,37 @@ def wsj_parsing_tokens(model_hparams, wrong_source_vocab_size,
   return p
 
 
+def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
+  """Icelandic to parse tree translation benchmark.
+
+  Args:
+    model_hparams: a tf.contrib.training.HParams
+  Returns:
+    a tf.contrib.training.HParams
+  """
+  p = default_problem_hparams()
+  # This vocab file must be present within the data directory.
+  source_vocab_filename = os.path.join(
+      model_hparams.data_dir,
+      "ice_source.tokens.vocab.%d" % wrong_source_vocab_size)
+  target_vocab_filename = os.path.join(
+      model_hparams.data_dir,
+      "ice_target.tokens.vocab.256")
+  source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
+  target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
+  p.input_modality = {
+      "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
+  }
+  p.target_modality = (registry.Modalities.SYMBOL, 256)
+  p.vocabulary = {
+      "inputs": source_subtokenizer,
+      "targets": target_subtokenizer,
+  }
+  p.input_space_id = 18  # Icelandic tokens
+  p.target_space_id = 19  # Icelandic parse tokens
+  return p
+
+
 def image_cifar10(unused_model_hparams):
   """CIFAR-10."""
   p = default_problem_hparams()
@@ -733,9 +768,11 @@ def img2img_imagenet(unused_model_hparams):
     "wiki_32k": wiki_32k,
     "lmptb_10k": lmptb_10k,
     "wmt_parsing_characters": wmt_parsing_characters,
+    "ice_parsing_characters": wmt_parsing_characters,
+    "ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
     "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
-    "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(p, 2**14, 2**9),
-    "wsj_parsing_tokens_32k": lambda p: wsj_parsing_tokens(p, 2**15, 2**9),
+    "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(p, "wsj", 2**14, 2**9),
+    "wsj_parsing_tokens_32k": lambda p: wsj_parsing_tokens(p, "wsj", 2**15, 2**9),
     "wmt_enfr_characters": wmt_enfr_characters,
    "wmt_enfr_tokens_8k": lambda p: wmt_enfr_tokens(p, 2**13),
    "wmt_enfr_tokens_32k": lambda p: wmt_enfr_tokens(p, 2**15),

tensor2tensor/data_generators/text_encoder.py

File mode changed from 100644 to 100755.
Lines changed: 22 additions & 15 deletions
@@ -37,27 +37,32 @@
 
 
 # Conversion between Unicode and UTF-8, if required (on Python2)
-def native_to_unicode(s):
-  return s.decode("utf-8") if (PY2 and not isinstance(s, unicode)) else s
-
-
-unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+if PY2:
+  native_to_unicode = lambda s: s if isinstance(s, unicode) else s.decode("utf-8")
+  unicode_to_native = lambda s: s.encode("utf-8")
+else:
+  # No conversion required on Python3
+  native_to_unicode = lambda s: s
+  unicode_to_native = lambda s: s
 
 
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
 EOS = "<EOS>"
 RESERVED_TOKENS = [PAD, EOS]
-if six.PY2:
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
+PAD_TOKEN = RESERVED_TOKENS.index(PAD)  # Normally 0
+EOS_TOKEN = RESERVED_TOKENS.index(EOS)  # Normally 1
+
+if PY2:
   RESERVED_TOKENS_BYTES = RESERVED_TOKENS
 else:
   RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
 
-
 class TextEncoder(object):
   """Base class for converting from ints to/from human readable strings."""
 
-  def __init__(self, num_reserved_ids=2):
+  def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
     self._num_reserved_ids = num_reserved_ids
 
   def encode(self, s):
@@ -105,7 +110,7 @@ class ByteTextEncoder(TextEncoder):
 
   def encode(self, s):
     numres = self._num_reserved_ids
-    if six.PY2:
+    if PY2:
      return [ord(c) + numres for c in s]
     # Python3: explicitly convert to UTF-8
     return [c + numres for c in s.encode("utf-8")]
@@ -119,10 +124,10 @@ def decode(self, ids):
         decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
       else:
         decoded_ids.append(int2byte(id_ - numres))
-    if six.PY2:
+    if PY2:
       return "".join(decoded_ids)
     # Python3: join byte arrays and then decode string
-    return b"".join(decoded_ids).decode("utf-8")
+    return b"".join(decoded_ids).decode("utf-8", "replace")
 
   @property
   def vocab_size(self):
@@ -132,7 +137,7 @@ def vocab_size(self):
 class TokenTextEncoder(TextEncoder):
   """Encoder based on a user-supplied vocabulary."""
 
-  def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
+  def __init__(self, vocab_filename, reverse=False, num_reserved_ids=NUM_RESERVED_TOKENS):
     """Initialize from a file, one token per line."""
     super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     self._reverse = reverse
@@ -345,7 +350,7 @@ def build_from_token_counts(self,
                              token_counts,
                              min_count,
                              num_iterations=4,
-                             num_reserved_ids=2):
+                             num_reserved_ids=NUM_RESERVED_TOKENS):
     """Train a SubwordTextEncoder based on a dictionary of word counts.
 
     Args:
@@ -371,6 +376,8 @@ def build_from_token_counts(self,
     # We build iteratively. On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
+    if min_count < 1:
+      min_count = 1
     for i in xrange(num_iterations):
       tf.logging.info("Iteration {0}".format(i))
       counts = defaultdict(int)
@@ -462,7 +469,7 @@ def store_to_file(self, filename):
        f.write("'" + unicode_to_native(subtoken_string) + "'\n")
 
   def _escape_token(self, token):
-    r"""Escape away underscores and OOV characters and append '_'.
+    """Escape away underscores and OOV characters and append '_'.
 
     This allows the token to be experessed as the concatenation of a list
     of subtokens from the vocabulary. The underscore acts as a sentinel
@@ -484,7 +491,7 @@ def _escape_token(self, token):
     return ret
 
   def _unescape_token(self, escaped_token):
-    r"""Inverse of _escape_token().
+    """Inverse of _escape_token().
 
     Args:
       escaped_token: a unicode string
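As a quick, hedged illustration of how the new reserved-token constants line up with ByteTextEncoder (the sample string and printed ids are assumptions based on the default offset of two reserved ids):

    enc = ByteTextEncoder()            # num_reserved_ids defaults to NUM_RESERVED_TOKENS
    ids = enc.encode("já")             # UTF-8 bytes, each shifted up by the reserved offset
    print(ids)                         # [108, 197, 163] with an offset of 2
    print(enc.decode(ids))             # "já"
    print(enc.decode([EOS_TOKEN]))     # "<EOS>", looked up in RESERVED_TOKENS_BYTES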

tensor2tensor/data_generators/tokenizer.py

File mode changed from 100644 to 100755.
Lines changed: 1 addition & 0 deletions
@@ -141,6 +141,7 @@ def read_corpus():
         if corpus_max_lines > 0 and lines_read > corpus_max_lines:
           return docs
     return docs
+
   counts = defaultdict(int)
   for doc in read_corpus():
     for tok in encode(_native_to_unicode(doc)):
tensor2tensor/data_generators/wmt.py

File mode changed from 100644 to 100755.
Lines changed: 64 additions & 4 deletions
@@ -38,9 +38,8 @@
 FLAGS = tf.flags.FLAGS
 
 
-# End-of-sentence marker (should correspond to the position of EOS in the
-# RESERVED_TOKENS list in text_encoder.py)
-EOS = 1
+# End-of-sentence marker
+EOS = text_encoder.EOS_TOKEN
 
 
 def character_generator(source_path, target_path, character_vocab, eos=None):
@@ -72,6 +71,35 @@ def character_generator(source_path, target_path, character_vocab, eos=None):
     source, target = source_file.readline(), target_file.readline()
 
 
+def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
+  """Generator for sequence-to-sequence tasks using tokens derived from
+  text files where each line contains both a source and a target string.
+  The two strings are separated by a tab character ('\t'). It yields
+  dictionaries of "inputs" and "targets" where inputs are characters
+  from the source lines converted to integers, and targets are
+  characters from the target lines, also converted to integers.
+
+  Args:
+    source_path: path to the file with source and target sentences.
+    source_vocab: a SubwordTextEncoder to encode the source string.
+    target_vocab: a SubwordTextEncoder to encode the target string.
+    eos: integer to append at the end of each sequence (default: None).
+
+  Yields:
+    A dictionary {"inputs": source-line, "targets": target-line} where
+    the lines are integer lists converted from characters in the file lines.
+  """
+  eos_list = [] if eos is None else [eos]
+  with tf.gfile.GFile(source_path, mode="r") as source_file:
+    for line in source_file:
+      if line and '\t' in line:
+        parts = line.split('\t', maxsplit = 1)
+        source, target = parts[0].strip(), parts[1].strip()
+        source_ints = source_vocab.encode(source) + eos_list
+        target_ints = source_vocab.encode(target) + eos_list
+        yield {"inputs": source_ints, "targets": target_ints}
+
+
 def token_generator(source_path, target_path, token_vocab, eos=None):
   """Generator for sequence-to-sequence tasks that uses tokens.
 
154182
train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path)
155183
token_path = os.path.join(tmp_dir, "vocab.bpe.32000")
156184
token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path)
157-
return token_generator(train_path + ".en", train_path + ".de", token_vocab, 1)
185+
return token_generator(train_path + ".en", train_path + ".de", token_vocab, EOS)
158186

159187

160188
_ENDE_TRAIN_DATASETS = [
@@ -339,6 +367,38 @@ def enfr_character_generator(tmp_dir, train):
   return character_generator(data_path + ".lang1", data_path + ".lang2",
                              character_vocab, EOS)
 
+def parsing_character_generator(tmp_dir, train):
+  character_vocab = text_encoder.ByteTextEncoder()
+  filename = "parsing_%s" % ("train" if train else "dev")
+  text_filepath = os.path.join(tmp_dir, filename + ".text")
+  tags_filepath = os.path.join(tmp_dir, filename + ".tags")
+  return character_generator(text_filepath, tags_filepath, character_vocab, EOS)
+
+
+def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size, target_vocab_size):
+  """Generate source and target data from a single file with source/target pairs
+  separated by a tab character ('\t')"""
+  source_vocab = generator_utils.get_or_generate_tabbed_vocab(
+      tmp_dir, "parsing_train.pairs", 0,
+      prefix + "_source.tokens.vocab.%d" % source_vocab_size,
+      source_vocab_size)
+  target_vocab = generator_utils.get_or_generate_tabbed_vocab(
+      tmp_dir, "parsing_train.pairs", 1,
+      prefix + "_target.tokens.vocab.%d" % target_vocab_size,
+      target_vocab_size)
+  filename = "parsing_%s" % ("train" if train else "dev")
+  pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
+  return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
+
+
+def tabbed_parsing_character_generator(tmp_dir, train):
+  """Generate source and target data from a single file with source/target pairs
+  separated by a tab character ('\t')"""
+  character_vocab = text_encoder.ByteTextEncoder()
+  filename = "parsing_%s" % ("train" if train else "dev")
+  pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
+  return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
+
 
 def parsing_token_generator(tmp_dir, train, vocab_size):
   symbolizer_vocab = generator_utils.get_or_generate_vocab(
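To make the expected on-disk layout concrete, a hedged sketch of driving the new tabbed generators by hand; the tmp_dir path is made up, while parsing_train.pairs and parsing_dev.pairs are the file names hard-coded above:

    # tmp_dir is expected to contain tab-separated pair files, one example per line:
    #   parsing_train.pairs:  <source sentence>\t<target parse string>
    #   parsing_dev.pairs:    same format, read when train=False
    gen = tabbed_parsing_token_generator("/tmp/t2t_tmp", True, "ice", 2**13, 2**8)
    example = next(gen)
    # Each example is a dict of int lists; tabbed_generator appends EOS to both sides.
    assert example["inputs"][-1] == EOS and example["targets"][-1] == EOS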

tensor2tensor/models/transformer.py

File mode changed from 100644 to 100755.
Lines changed: 9 additions & 0 deletions
@@ -353,6 +353,15 @@ def transformer_parsing_base():
   return hparams
 
 
+@registry.register_hparams
+def transformer_parsing_ice():
+  """Hparams for parsing Icelandic text."""
+  hparams = transformer_base_single_gpu()
+  hparams.batch_size = 4096
+  hparams.shared_embedding_and_softmax_weights = int(False)
+  return hparams
+
+
 @registry.register_hparams
 def transformer_parsing_big():
   """HParams for parsing on wsj semi-supervised."""
