This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit cff9f43
T2T Team authored and Ryan Sepassi committed
1 parent 28eb48f

Allow building a subword vocab from a word vocab file and add tests.
PiperOrigin-RevId: 163250427

File tree: 9 files changed (+387, -146 lines)
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+One morning I shot an elephant in my pajamas. How he got in my pajamas, I don't
+know.
+
+Groucho Marx
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+I haven't slept for 10 days... because that would be too long.
+
+Mitch Hedberg
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+lollipop,8
+reverberated,12
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+kattywampus,11
+balderdash,10
+jiggery-pokery,14
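
The two vocab fixtures above follow the word-count format that the new --vocab_filepattern flag documents further down: one "word,count" entry per line. As a rough illustration of how such a file maps to the token counts the builder consumes, here is a hypothetical stand-alone parser; it is not the tokenizer.vocab_token_counts helper this commit actually calls, and the filename is illustrative.

```python
import collections


def read_word_count_vocab(filename):
  """Read a vocab file with one "word,count" entry per line (sketch only)."""
  token_counts = collections.defaultdict(int)
  with open(filename) as vocab_file:
    for line in vocab_file:
      line = line.strip()
      if not line:
        continue
      # Split on the last comma so a word may itself contain commas.
      word, _, count = line.rpartition(",")
      token_counts[word] += int(count)
  return dict(token_counts)


# Hypothetical usage on the second fixture above:
# read_word_count_vocab("vocab.txt")
# -> {"kattywampus": 11, "balderdash": 10, "jiggery-pokery": 14}
```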

tensor2tensor/data_generators/text_encoder.py

Lines changed: 66 additions & 57 deletions
@@ -30,11 +30,11 @@
 # Dependency imports
 
 import six
-from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import tokenizer
 
 import tensorflow as tf
 
+xrange = six.moves.xrange  # pylint: disable=redefined-builtin
 
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
@@ -295,7 +295,7 @@ def encode(self, raw_text):
     Returns:
       a list of integers in the range [0, vocab_size)
     """
-    return self._tokens_to_subtokens(tokenizer.encode(
+    return self._tokens_to_subtoken_ids(tokenizer.encode(
         native_to_unicode(raw_text)))
 
   def decode(self, subtokens):
@@ -307,14 +307,14 @@ def decode(self, subtokens):
       a native string
     """
     return unicode_to_native(tokenizer.decode(
-        self._subtokens_to_tokens(subtokens)))
+        self._subtoken_ids_to_tokens(subtokens)))
 
   @property
   def vocab_size(self):
     """The subtoken vocabulary size."""
     return len(self._all_subtoken_strings)
 
-  def _tokens_to_subtokens(self, tokens):
+  def _tokens_to_subtoken_ids(self, tokens):
     """Converts a list of tokens to a list of subtoken ids.
 
     Args:
@@ -324,11 +324,11 @@ def _tokens_to_subtokens(self, tokens):
     """
     ret = []
     for token in tokens:
-      ret.extend(self._escaped_token_to_subtokens(
+      ret.extend(self._escaped_token_to_subtoken_ids(
           _escape_token(token, self._alphabet)))
     return ret
 
-  def _subtokens_to_tokens(self, subtokens):
+  def _subtoken_ids_to_tokens(self, subtokens):
     """Converts a list of subtoken ids to a list of tokens.
 
     Args:
@@ -337,45 +337,58 @@ def _subtokens_to_tokens(self, subtokens):
       a list of strings.
     """
     concatenated = "".join(
-        [self._subtoken_to_subtoken_string(s) for s in subtokens])
+        [self._subtoken_id_to_subtoken_string(s) for s in subtokens])
     split = concatenated.split("_")
     return [_unescape_token(t + "_") for t in split if t]
 
-  def _subtoken_to_subtoken_string(self, subtoken):
-    """Subtoken_String (string) corresponding to the given subtoken (id)."""
+  def _subtoken_id_to_subtoken_string(self, subtoken):
+    """Converts a subtoken integer ID to a subtoken string."""
     if 0 <= subtoken < self.vocab_size:
       return self._all_subtoken_strings[subtoken]
     return u""
 
-  def _escaped_token_to_subtokens(self, escaped_token):
-    """Converts an escaped token string to a list of subtokens.
+  def _escaped_token_to_subtoken_strings(self, escaped_token):
+    """Converts an escaped token string to a list of subtoken strings.
 
     Args:
-      escaped_token: an escaped token
+      escaped_token: An escaped token as a unicode string.
     Returns:
-      a list of one or more integers.
+      A list of subtokens as unicode strings.
     """
+    # NOTE: This algorithm is greedy; it won't necessarily produce the "best"
+    # list of subtokens.
     ret = []
-    pos = 0
-    lesc = len(escaped_token)
-    while pos < lesc:
-      end = min(lesc, pos + self._max_subtoken_len)
-      while end > pos:
-        subtoken_id = self._subtoken_string_to_id.get(escaped_token[pos:end])
-        if subtoken_id is not None:
+    start = 0
+    token_len = len(escaped_token)
+    while start < token_len:
+      for end in xrange(
+          min(token_len, start + self._max_subtoken_len), start, -1):
+        subtoken = escaped_token[start:end]
+        if subtoken in self._subtoken_string_to_id:
+          ret.append(subtoken)
+          start = end
          break
-        end -= 1
 
-      # If there is no possible encoding of the escaped token then one of the
-      # characters in the token is not in the alphabet. This should be
-      # impossible and would be indicative of a bug.
-      assert subtoken_id is not None
-
-      ret.append(subtoken_id)
-      pos = end
+      else:  # Did not break
+        # If there is no possible encoding of the escaped token then one of the
+        # characters in the token is not in the alphabet. This should be
+        # impossible and would be indicative of a bug.
+        assert False, "Token substring not found in subtoken vocabulary."
 
     return ret
 
+  def _escaped_token_to_subtoken_ids(self, escaped_token):
+    """Converts an escaped token string to a list of subtoken IDs.
+
+    Args:
+      escaped_token: An escaped token as a unicode string.
+    Returns:
+      A list of subtoken IDs as integers.
+    """
+    return [
+        self._subtoken_string_to_id[subtoken]
+        for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)]
+
   @classmethod
   def build_to_target_size(cls,
                            target_size,
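
The rewritten _escaped_token_to_subtoken_strings above is a greedy longest-match segmenter: at each position it tries the longest window (capped at _max_subtoken_len) and backs off one character at a time, and the new _escaped_token_to_subtoken_ids then maps each matched string through _subtoken_string_to_id. Below is a self-contained sketch of the same idea, using an ordinary set in place of the encoder's internal vocabulary; the function and variable names are illustrative, not part of the commit.

```python
def greedy_subtoken_split(escaped_token, subtoken_vocab, max_subtoken_len):
  """Greedily split escaped_token into the longest pieces found in the vocab."""
  subtokens = []
  start = 0
  token_len = len(escaped_token)
  while start < token_len:
    # Try the longest candidate first and back off one character at a time.
    for end in range(min(token_len, start + max_subtoken_len), start, -1):
      candidate = escaped_token[start:end]
      if candidate in subtoken_vocab:
        subtokens.append(candidate)
        start = end
        break
    else:
      # Unreachable if every character of escaped_token is in the vocabulary,
      # which the alphabet handling in build_from_token_counts guarantees.
      raise ValueError("no subtoken matches %r" % escaped_token[start:])
  return subtokens


# Example: with vocab {"re", "verb", "erated_", "r", "e", "v", "b", "a", "t", "d"},
# greedy_subtoken_split(u"reverberated_", vocab, 10) returns
# ["re", "verb", "erated_"].
```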
@@ -460,55 +473,51 @@ def build_from_token_counts(self,
       min_count = 1
     for i in xrange(num_iterations):
       tf.logging.info("Iteration {0}".format(i))
-      counts = collections.defaultdict(int)
+
+      # Collect all substrings of the encoded token that break along current
+      # subtoken boundaries.
+      subtoken_counts = collections.defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = _escape_token(token, self._alphabet)
-        # we will count all tails of the escaped_token, starting from boundaries
-        # determined by our current segmentation.
-        if i == 0:
-          starts = six.moves.range(len(escaped_token))
-        else:
-          subtokens = self._escaped_token_to_subtokens(escaped_token)
-          pos = 0
-          starts = []
-          for subtoken in subtokens:
-            starts.append(pos)
-            pos += len(self._all_subtoken_strings[subtoken])
-        for start in starts:
+        subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
+        start = 0
+        for subtoken in subtokens:
           for end in xrange(start + 1, len(escaped_token) + 1):
-            subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] += count
-      # Array of sets of candidate subtoken strings, by length
+            new_subtoken = escaped_token[start:end]
+            subtoken_counts[new_subtoken] += count
+          start += len(subtoken)
+
+      # Array of sets of candidate subtoken strings, by length.
       len_to_subtoken_strings = []
-      for subtoken_string, count in six.iteritems(counts):
+      for subtoken_string, count in six.iteritems(subtoken_counts):
         lsub = len(subtoken_string)
-        # Always include all the alphabet characters or some strings will
-        # be unencodeable.
-        if count >= min_count or subtoken_string in self._alphabet:
-          # Add this subtoken string to its length set
+        if count >= min_count:
           while len(len_to_subtoken_strings) <= lsub:
            len_to_subtoken_strings.append(set())
          len_to_subtoken_strings[lsub].add(subtoken_string)
-      new_subtoken_strings = []
+
       # Consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
+      new_subtoken_strings = []
       for lsub in xrange(len(len_to_subtoken_strings)-1, 0, -1):
         subtoken_strings = len_to_subtoken_strings[lsub]
         for subtoken_string in subtoken_strings:
-          count = counts[subtoken_string]
-          if count >= min_count or subtoken_string in self._alphabet:
-            # Exclude alphabet tokens here, as they must be included later
+          count = subtoken_counts[subtoken_string]
+          if count >= min_count:
+            # Exclude alphabet tokens here, as they must be included later,
             # explicitly, regardless of count.
             if subtoken_string not in self._alphabet:
               new_subtoken_strings.append((count, subtoken_string))
             for l in xrange(1, lsub):
-              counts[subtoken_string[:l]] -= count
+              subtoken_counts[subtoken_string[:l]] -= count
+
+      # Include the alphabet explicitly to guarantee all strings are encodable.
+      new_subtoken_strings.extend(
+          (subtoken_counts.get(a, 0), a) for a in self._alphabet)
       new_subtoken_strings.sort(reverse=True)
 
-      # Reinitialize to the candidate vocabulary, including the alphabet
-      # explicitly as the highest priority.
+      # Reinitialize to the candidate vocabulary.
       self._init_subtokens_from_list(
-          list(self._alphabet) +
           [subtoken for _, subtoken in new_subtoken_strings],
           reserved=num_reserved_ids)
       tf.logging.info("vocab_size = %d" % self.vocab_size)
tensor2tensor/data_generators/text_encoder_build_subword.py

Lines changed: 25 additions & 11 deletions
@@ -39,10 +39,13 @@
 
 import tensorflow as tf
 
-tf.app.flags.DEFINE_string('output_fn', '/tmp/my.subword_text_encoder',
+tf.app.flags.DEFINE_string('output_filename', '/tmp/my.subword_text_encoder',
                           'where to store the SubwordTextEncoder')
 tf.app.flags.DEFINE_string('corpus_filepattern', '',
                           'Corpus of one or more text files')
+tf.app.flags.DEFINE_string('vocab_filepattern', '',
+                          'One or more vocabulary files '
+                          '(one word per line as "word,count")')
 tf.app.flags.DEFINE_integer('min_count', 5, 'Minimum subtoken count in corpus')
 tf.app.flags.DEFINE_integer('corpus_max_lines', 10000,
                            'How many lines of corpus to read')
@@ -52,16 +55,27 @@
 
 
 def main(unused_argv):
-  gs = text_encoder.SubwordTextEncoder()
-  if not FLAGS.corpus_filepattern:
-    raise ValueError('Must provide --corpus_filepattern')
-  token_counts = tokenizer.corpus_token_counts(
-      FLAGS.corpus_filepattern, FLAGS.corpus_max_lines,
-      split_on_newlines=FLAGS.split_on_newlines)
-  gs.build_from_token_counts(token_counts,
-                             FLAGS.min_count,
-                             FLAGS.num_iterations)
-  gs.store_to_file(FLAGS.output_fn)
+  if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern:
+    raise ValueError(
+        'Must only provide one of --corpus_filepattern or --vocab_filepattern')
+
+  elif FLAGS.corpus_filepattern:
+    token_counts = tokenizer.corpus_token_counts(
+        FLAGS.corpus_filepattern, FLAGS.corpus_max_lines,
+        split_on_newlines=FLAGS.split_on_newlines)
+
+  elif FLAGS.vocab_filepattern:
+    token_counts = tokenizer.vocab_token_counts(
+        FLAGS.vocab_filepattern, FLAGS.corpus_max_lines)
+
+  else:
+    raise ValueError(
+        'Must provide one of --corpus_filepattern or --vocab_filepattern')
+
+  encoder = text_encoder.SubwordTextEncoder()
+  encoder.build_from_token_counts(
+      token_counts, FLAGS.min_count, FLAGS.num_iterations)
+  encoder.store_to_file(FLAGS.output_filename)
 
 
 if __name__ == '__main__':
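
With the new dispatch in main(), the builder accepts either a raw corpus or word-count vocab files like the fixtures above. Below is a hypothetical programmatic equivalent of the vocab path, feeding counts straight to the encoder; the counts, output path, and iteration count are illustrative, not taken from the commit's tests.

```python
from tensor2tensor.data_generators import text_encoder

# Token counts along the lines of the word-count fixtures shown earlier.
token_counts = {
    u"lollipop": 8, u"reverberated": 12,
    u"kattywampus": 11, u"balderdash": 10, u"jiggery-pokery": 14,
}

encoder = text_encoder.SubwordTextEncoder()
encoder.build_from_token_counts(token_counts, 2, 4)  # min_count=2, 4 iterations
encoder.store_to_file("/tmp/my.subword_text_encoder")

# The public API round-trips text through the learned subword vocabulary.
ids = encoder.encode("balderdash reverberated")
print(encoder.decode(ids))  # expected to round-trip to "balderdash reverberated"
```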
