
Commit af235c1

Merge pull request #66 from vthorsteinsson/fix-tokens

Unicode in SubwordTextEncoder

2 parents: c4c768f + bc75385

5 files changed: +152 −124 lines

tensor2tensor/data_generators/generator_utils.py

File mode changed from 100644 to 100755
Lines changed: 7 additions & 3 deletions

@@ -242,9 +242,12 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):

       # For some datasets a second extraction is necessary.
       if ".gz" in lang_file:
-        tf.logging.info("Unpacking subdirectory %s" % filepath)
         new_filepath = os.path.join(tmp_dir, lang_file[:-3])
-        gunzip_file(filepath, new_filepath)
+        if os.path.exists(new_filepath):
+          tf.logging.info("Subdirectory %s already exists, skipping unpacking" % filepath)
+        else:
+          tf.logging.info("Unpacking subdirectory %s" % filepath)
+          gunzip_file(filepath, new_filepath)
         filepath = new_filepath

       # Use Tokenizer to count the word occurrences.

@@ -258,7 +261,8 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
       _ = tokenizer.encode(line)

   vocab = SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, vocab_filepath, 1, 1e3)
+      vocab_size, tokenizer.token_counts, 1, 1e3)
+  vocab.store_to_file(vocab_filepath)
   return vocab

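The net effect of this hunk is twofold: the gunzip step is skipped when the extracted file already exists, and build_to_target_size no longer writes the vocabulary itself; the caller persists it explicitly. A minimal sketch of the new caller-side flow, assuming the module path tensor2tensor.data_generators.text_encoder and toy token counts (the real code derives them from the downloaded corpus):

from tensor2tensor.data_generators import text_encoder

# Toy counts standing in for tokenizer.token_counts built from the corpus.
token_counts = {u"hello": 12, u"world": 7, u"heldur": 3}

# The store_filename argument is gone from build_to_target_size();
# writing the vocabulary is now an explicit, separate step.
vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
    100, token_counts, 1, 1e3)
vocab.store_to_file("/tmp/vocab.subwords")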

tensor2tensor/data_generators/snli.py

File mode changed from 100644 to 100755
Lines changed: 8 additions & 8 deletions

@@ -136,14 +136,14 @@ def _get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size):
   if tf.gfile.Exists(vocab_filepath):
     gs = text_encoder.SubwordTextEncoder(vocab_filepath)
     return gs
-  else:
-    example_file = os.path.join(tmp_dir, _EXAMPLES_FILE)
-    gs = text_encoder.SubwordTextEncoder()
-    token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
-        example_file, corpus_max_lines=1000000)
-    gs = gs.build_to_target_size(
-        vocab_size, token_counts, vocab_filepath, min_val=1, max_val=1e3)
-    return gs
+  example_file = os.path.join(tmp_dir, _EXAMPLES_FILE)
+  gs = text_encoder.SubwordTextEncoder()
+  token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
+      example_file, corpus_max_lines=1000000)
+  gs = gs.build_to_target_size(
+      vocab_size, token_counts, min_val=1, max_val=1e3)
+  gs.store_to_file(vocab_filepath)
+  return gs


 def snli_token_generator(tmp_dir, train, vocab_size):
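The same API change, seen from the SNLI generator: the vocabulary file becomes a cache that is reused when present and written explicitly otherwise. A hedged sketch of that cache-or-build pattern, with illustrative arguments standing in for the real SNLI constants (_EXAMPLES_FILE, tmp_dir):

import tensorflow as tf
from tensor2tensor.data_generators import text_encoder

def get_or_build_vocab(vocab_filepath, examples_file, vocab_size):
  # Reuse a previously stored vocabulary if one exists.
  if tf.gfile.Exists(vocab_filepath):
    return text_encoder.SubwordTextEncoder(vocab_filepath)
  token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
      examples_file, corpus_max_lines=1000000)
  gs = text_encoder.SubwordTextEncoder.build_to_target_size(
      vocab_size, token_counts, min_val=1, max_val=1e3)
  gs.store_to_file(vocab_filepath)  # persist for the next run
  return gs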

tensor2tensor/data_generators/text_encoder.py

File mode changed from 100644 to 100755
Lines changed: 89 additions & 85 deletions
@@ -175,9 +175,9 @@ class SubwordTextEncoder(TextEncoder):
   """

   def __init__(self, filename=None, num_reserved_ids=2):
-    """Read from a file."""
     self._tokenizer = tokenizer.Tokenizer()
     if filename is not None:
+      # Read from a file.
       self._load_from_file(filename)

     super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
@@ -235,14 +235,13 @@ def _subtokens_to_tokens(self, subtokens):

   def subtoken_to_subtoken_string(self, subtoken):
     """Subtoken_String (string) corresponding to the given subtoken (id)."""
-    if (subtoken >= 0 and subtoken < self.vocab_size and
-        self._all_subtoken_strings[subtoken]):
-      return self._all_subtoken_strings[subtoken]
-    else:
-      if 0 <= subtoken < self._num_reserved_ids:
-        return '%s_' % RESERVED_TOKENS[subtoken]
-      else:
-        return 'ID%d_' % subtoken
+    if 0 <= subtoken < self.vocab_size:
+      subtoken_string = self._all_subtoken_strings[subtoken]
+      if subtoken_string:
+        return subtoken_string
+    if 0 <= subtoken < self._num_reserved_ids:
+      return '%s_' % RESERVED_TOKENS[subtoken]
+    return 'ID%d_' % subtoken

   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.
@@ -262,21 +261,32 @@ def _escaped_token_to_subtokens(self, escaped_token):
         if subtoken != -1:
           break
         end -= 1
-      ret.append(subtoken)
       if end > pos:
+        ret.append(subtoken)
         pos = end
       else:
-        # This kinda should not happen, but it does. Cop out by skipping the
-        # nonexistent subtoken from the returned list.
-        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        # No subtoken in the vocabulary matches escaped_token[pos].
+        # This can happen if the token contains a Unicode character
+        # that did not occur in the vocabulary training set.
+        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
+        # REPLACEMENT_CHARACTER.
+        ret.append(self.vocab_size - 1)
+        # Ensure that the outer loop continues
         pos += 1
     return ret

+  @classmethod
+  def alphabet(cls, token_counts):
+    """Return the set of Unicode characters that appear in the tokens"""
+    alphabet_set = set()
+    for token in six.iterkeys(token_counts):
+      alphabet_set |= set(token)
+    return alphabet_set
+
   @classmethod
   def build_to_target_size(cls,
                            target_size,
                            token_counts,
-                           store_filename,
                            min_val,
                            max_val,
                            num_iterations=4):
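Two additions above work together: alphabet() collects every Unicode character seen in the training tokens, and _escaped_token_to_subtokens() falls back to id vocab_size - 1, which decodes to U+FFFD (the REPLACEMENT CHARACTER), when it meets a character outside that alphabet. A standalone sketch of the alphabet computation, on toy counts:

import six

def alphabet(token_counts):
  # Union of the characters of every token key.
  alphabet_set = set()
  for token in six.iterkeys(token_counts):
    alphabet_set |= set(token)
  return alphabet_set

print(sorted(alphabet({u"þú": 3, u"ert": 2})))  # characters e, r, t, ú, þ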
@@ -296,50 +306,51 @@ def build_to_target_size(cls,
     Returns:
       a SubwordTextEncoder instance.
     """
-    present_count = (max_val + min_val) // 2
-    tf.logging.info('Trying min_count %d' % present_count)
-    subtokenizer = cls()
-    subtokenizer.build_from_token_counts(token_counts, store_filename,
-                                         present_count, num_iterations)
-
-    if min_val >= max_val or subtokenizer.vocab_size == target_size:
-      return subtokenizer
-    elif subtokenizer.vocab_size > target_size:
-      other_subtokenizer = cls.build_to_target_size(
-          target_size, token_counts, store_filename, present_count + 1, max_val,
-          num_iterations)
-      if (abs(other_subtokenizer.vocab_size - target_size) <
-          abs(subtokenizer.vocab_size - target_size)):
-        return other_subtokenizer
-      else:
+
+    # Calculate the alphabet, i.e. the set of all Unicode characters
+    # that appear in the tokens
+    alphabet_set = cls.alphabet(token_counts)
+    tf.logging.info('Alphabet contains %d characters' % len(alphabet_set))
+
+    def bisect(min_val, max_val):
+      present_count = (max_val + min_val) // 2
+      tf.logging.info('Trying min_count %d' % present_count)
+      subtokenizer = cls()
+      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+                                           present_count, num_iterations)
+
+      if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
-    else:
-      other_subtokenizer = cls.build_to_target_size(
-          target_size, token_counts, store_filename, min_val, present_count - 1,
-          num_iterations)
+      if subtokenizer.vocab_size > target_size:
+        other_subtokenizer = bisect(present_count + 1, max_val)
+      else:
+        other_subtokenizer = bisect(min_val, present_count - 1)
       if (abs(other_subtokenizer.vocab_size - target_size) <
           abs(subtokenizer.vocab_size - target_size)):
         return other_subtokenizer
       else:
         return subtokenizer

+    return bisect(min_val, max_val)
+
   def build_from_token_counts(self,
                               token_counts,
-                              store_filename,
+                              alphabet_set,
                               min_count,
                               num_iterations=4):
     """Train a SubwordTextEncoder based on a dictionary of word counts.

     Args:
-      token_counts: a dictionary of string to int.
-      store_filename: a string - where to write the vocabulary.
+      token_counts: a dictionary of Unicode strings to int.
+      alphabet_set: the set of Unicode characters that appear in the tokens.
       min_count: an integer - discard subtokens with lower counts.
       num_iterations: an integer. how many iterations of refinement.
     """
     # We build iteratively. On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
+      tf.logging.info("Iteration {0}".format(i))
       counts = defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = self._escape_token(token)
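The recursive calls to build_to_target_size are replaced by a nested bisect() helper, so the alphabet is computed once and the store_filename threading disappears. A simplified, standalone sketch of the same search, where build_vocab_size is a hypothetical stand-in for constructing a subtokenizer at a given min_count and reading its vocab_size:

def closest_min_count(target_size, build_vocab_size, min_val, max_val):
  def bisect(lo, hi):
    mid = (lo + hi) // 2
    size = build_vocab_size(mid)
    if lo >= hi or size == target_size:
      return mid, size
    if size > target_size:
      other = bisect(mid + 1, hi)   # a larger min_count shrinks the vocabulary
    else:
      other = bisect(lo, mid - 1)   # a smaller min_count grows it
    # Keep whichever candidate lands closer to the target.
    return other if abs(other[1] - target_size) < abs(size - target_size) else (mid, size)
  return bisect(min_val, max_val)

# Pretend the vocabulary size falls off roughly as 5000 // min_count.
print(closest_min_count(250, lambda c: 5000 // c, 1, 1000))  # -> (20, 250)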
@@ -353,39 +364,49 @@ def build_from_token_counts(self,
         starts = []
         for subtoken in subtokens:
           starts.append(pos)
-          pos += len(self.subtoken_to_subtoken_string(subtoken))
+          pos += len(self._all_subtoken_strings[subtoken])
         for start in starts:
-          for end in xrange(start + 1, len(escaped_token)):
+          for end in xrange(start + 1, len(escaped_token) + 1):
             subtoken_string = escaped_token[start:end]
             counts[subtoken_string] += count
-      # array of lists of candidate subtoken strings, by length
+      # Array of sets of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
         lsub = len(subtoken_string)
-        # all subtoken strings of length 1 are included regardless of count
-        if count < min_count and lsub != 1:
+        # All subtoken strings of length 1 are automatically included
+        # later, so we don't need to consider them here
+        if count < min_count or lsub <= 1:
           continue
+        # Add this subtoken string to its length set
        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[lsub].append(subtoken_string)
+          len_to_subtoken_strings.append(set())
+        len_to_subtoken_strings[lsub].add(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in len_to_subtoken_strings[::-1]:
+      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count and len(subtoken_string) != 1:
-            # subtoken strings of length 1 are included regardless of count
+          if count < min_count:
             continue
-          new_subtoken_strings.append((-count, subtoken_string))
+          new_subtoken_strings.append((count, subtoken_string))
           for l in xrange(1, len(subtoken_string)):
             counts[subtoken_string[:l]] -= count
-      # Make sure to include the underscore as a subtoken string
-      new_subtoken_strings.append((0, '_'))
-      new_subtoken_strings.sort()
-      self._init_from_list([''] * self._num_reserved_ids +
+      # Sort what we've got so far in decreasing order by count
+      new_subtoken_strings.sort(reverse = True)
+      # Add the alphabet set at the end of the vocabulary list
+      for char in alphabet_set:
+        new_subtoken_strings.append((0, char))
+      # Also include the Unicode REPLACEMENT CHARACTER to use
+      # when encountering previously unseen Unicode characters
+      # in the input (i.e. input external to the tokenizer training
+      # set, which may thus contain characters not in the alphabet_set).
+      # This must be the last entry in the subtoken vocabulary list.
+      new_subtoken_strings.append((0, u'\uFFFD'))
+      # Now we have a candidate vocabulary
+      self._init_from_list([u''] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
-      print('vocab_size = %d' % self.vocab_size)
+      tf.logging.info('vocab_size = %d' % self.vocab_size)

     original = 'This sentence was encoded by the SubwordTextEncoder.'
     encoded = self.encode(original)
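The end of each iteration now assembles the vocabulary in a fixed order: reserved slots first, multi-character candidates sorted by descending count, then the single-character alphabet, and finally u'\uFFFD' so the last id always decodes to the REPLACEMENT CHARACTER. A toy illustration of that ordering (the counts and alphabet below are made up):

num_reserved_ids = 2
candidates = [(17, u'the_'), (4, u'ing_'), (9, u'er')]
alphabet_set = {u'a', u'b', u'c'}

entries = sorted(candidates, reverse=True)        # by count, descending
entries += [(0, ch) for ch in alphabet_set]       # every character seen in training
entries += [(0, u'\uFFFD')]                       # must stay last
vocab = [u''] * num_reserved_ids + [s for _, s in entries]
assert vocab[-1] == u'\uFFFD'
print(vocab)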
@@ -394,33 +415,33 @@ def build_from_token_counts(self,
     decoded = self.decode(encoded)
     print(decoded)
     assert decoded == original
-    self._store_to_file(store_filename)
+
+  def dump(self):
+    """ Debugging dump of the current subtoken vocabulary """
+    subtoken_strings = [(i, s) for s, i in six.iteritems(self._subtoken_string_to_id)]
+    print(u", ".join(u"{0} : '{1}'".format(i, s) for i, s in sorted(subtoken_strings)))

   def _init_from_list(self, subtoken_strings):
     """Initialize from a list of subtoken strings."""
     self._all_subtoken_strings = subtoken_strings
-    self._subtoken_string_to_id = {}
-    for i in xrange(len(subtoken_strings)):
-      subtoken_string = subtoken_strings[i]
-      if subtoken_string:
-        self._subtoken_string_to_id[subtoken_string] = i
+    self._subtoken_string_to_id = { s : i for i, s in enumerate(subtoken_strings) if s }

   def _load_from_file(self, filename):
     """Load from a file."""
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
         if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+          subtoken_strings.append(line.strip()[1:-1].decode('utf-8'))
         else:
           subtoken_strings.append(line.strip()[1:-1])
     self._init_from_list(subtoken_strings)

-  def _store_to_file(self, filename):
+  def store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
         if six.PY2:
-          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+          f.write('\'' + subtoken_string.encode('utf-8') + '\'\n')
         else:
           f.write('\'' + subtoken_string + '\'\n')
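The rename to store_to_file() makes persistence part of the public API, and both directions now use UTF-8 under Python 2 instead of 'string-escape'. A sketch of the round trip implied by the quoted-line format, using io.open as a stand-in for tf.gfile.Open and an illustrative path:

import io

subtokens = [u'', u'', u'the_', u'þ']   # leading empty strings are reserved slots
with io.open('/tmp/vocab.subwords', 'w', encoding='utf-8') as f:
  for s in subtokens:
    f.write(u"'" + s + u"'\n")          # one quoted subtoken per line

with io.open('/tmp/vocab.subwords', encoding='utf-8') as f:
  loaded = [line.strip()[1:-1] for line in f]
assert loaded == subtokens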

@@ -437,43 +458,26 @@ def _escape_token(self, token):
   def _unescape_token(self, escaped_token):
     r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.

-    TODO(noam): There must be some better way to do this with regexps.
-
     Args:
       escaped_token: a string
     Returns:
       token: a string
     """
     assert escaped_token[-1] == '_'
-    escaped_token = escaped_token[:-1]
-    if '\\' not in escaped_token:
-      return escaped_token
-    ret = ''
-    pos = 0
-    while pos < len(escaped_token):
-      if escaped_token[pos] == '\\' and pos + 1 < len(escaped_token):
-        if escaped_token[pos + 1] == 'u':
-          ret += '_'
-        else:
-          ret += escaped_token[pos + 1]
-        pos += 1
-      pos += 1
-    return ret
+    return escaped_token[:-1].replace('\\u', '_').replace('\\\\', '\\')

   @classmethod
   def get_token_counts(cls, text_filepattern, corpus_max_lines):
-    """Read the corpus and compute a dictionary of word counts."""
+    """Read the corpus and compute a dictionary of token counts."""
     tok = tokenizer.Tokenizer()
-    token_counts = {}
     lines_read = 0
     filenames = tf.gfile.Glob(text_filepattern)
     for text_filename in filenames:
       with tf.gfile.Open(text_filename) as f:
         for line in f:
-          tokens = tok.encode(line.strip())
-          for t in tokens:
-            token_counts[t] = token_counts.get(t, 0) + 1
+          # The tokenizer updates token_counts in encode()
+          tok.encode(line.strip())
           lines_read += 1
           if corpus_max_lines > 0 and lines_read > corpus_max_lines:
-            return token_counts
-    return token_counts
+            return tok.token_counts
+    return tok.token_counts
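The hand-rolled unescaping loop collapses into two str.replace calls, and get_token_counts now simply drives the tokenizer, which counts tokens itself during encode(). A standalone sketch of the escape scheme the one-liner reverses, as described in the docstring; escape_token here is an illustrative inverse, not the library's _escape_token:

def escape_token(token):
  # '\' is stored as '\\', '_' as '\u', and a trailing '_' closes the token.
  return token.replace(u'\\', u'\\\\').replace(u'_', u'\\u') + u'_'

def unescape_token(escaped_token):
  assert escaped_token[-1] == u'_'
  return escaped_token[:-1].replace(u'\\u', u'_').replace(u'\\\\', u'\\')

for token in [u'under_score', u'back\\slash', u'plain']:
  assert unescape_token(escape_token(token)) == token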

tensor2tensor/data_generators/text_encoder_build_subword.py

File mode changed from 100644 to 100755
Lines changed: 4 additions & 1 deletion

@@ -59,8 +59,11 @@ def main(unused_argv):
     raise ValueError('Must provide --corpus_filepattern')
   token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
       FLAGS.corpus_filepattern, FLAGS.corpus_max_lines)
-  gs.build_from_token_counts(token_counts, FLAGS.output_fn, FLAGS.min_count,
+  alphabet_set = SubwordTextEncoder.alphabet(token_counts)
+  gs.build_from_token_counts(token_counts, alphabet_set,
+                             FLAGS.min_count,
                              FLAGS.num_iterations)
+  gs.store_to_file(FLAGS.output_fn)


 if __name__ == '__main__':
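With these changes the command-line tool builds a vocabulary in four explicit steps: count tokens, derive the alphabet, train, then store. A minimal sketch of that flow outside the flags machinery, assuming the module path tensor2tensor.data_generators.text_encoder and illustrative file names:

from tensor2tensor.data_generators import text_encoder

token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
    '/tmp/corpus-*.txt', corpus_max_lines=10000)
alphabet_set = text_encoder.SubwordTextEncoder.alphabet(token_counts)

gs = text_encoder.SubwordTextEncoder()
gs.build_from_token_counts(token_counts, alphabet_set,
                           min_count=5, num_iterations=4)
gs.store_to_file('/tmp/my.subword_text_encoder')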
