2828# Dependency imports
2929
3030import six
31+ from six import PY2
3132from six.moves import xrange # pylint: disable=redefined-builtin
3233from tensor2tensor.data_generators import tokenizer
3334
3435import tensorflow as tf
3536
37+
38+ # Conversion between Unicode and UTF-8, if required (on Python 2)
39+ _native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
40+
41+
42+ _unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
43+
44+
3645# Reserved tokens for things like padding and EOS symbols.
3746PAD = "<pad>"
3847EOS = "<EOS>"
@@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):
162171
163172
164173class SubwordTextEncoder(TextEncoder):
165- """Class for breaking tokens into subtokens.
174+ """Class for invertibly encoding text using a limited vocabulary.
166175
167- Invertibly encodes a string as a sequence of subtokens from a limited
176+ Invertibly encodes a native string as a sequence of subtokens from a limited
168177 vocabulary.
169178
170179 A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
171180 the corpus), and stored to a file. See text_encoder_build_subword.py.
172181
173182 It can then be loaded and used to encode/decode any text.
183+
184+ Encoding has four phases:
185+
186+ 1. Tokenize into a list of tokens. Each token is a unicode string of either
187+ all alphanumeric characters or all non-alphanumeric characters. We drop
188+ tokens consisting of a single space that are between two alphanumeric
189+ tokens.
190+
191+ 2. Escape each token. This escapes away special and out-of-vocabulary
192+ characters, and makes sure that each token ends with an underscore, and
193+ has no other underscores.
194+
195+ 3. Represent each escaped token as the concatenation of a list of subtokens
196+ from the limited vocabulary. Subtoken selection is done greedily from
197+ beginning to end. That is, we construct the list in order, always picking
198+ the longest subtoken in our vocabulary that matches a prefix of the
199+ remaining portion of the encoded token.
200+
201+ 4. Concatenate these lists. This concatenation is invertible due to the
202+ fact that the trailing underscores indicate when one list is finished.
203+
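+ For example (illustrative only; the actual subtokens depend on the trained
+ vocabulary), the native string "greetings!" might be processed as:
+
+ tokens:          [u"greetings", u"!"]
+ escaped tokens:  [u"greetings_", u"!_"]
+ subtoken lists:  [[u"greet", u"ings_"], [u"!_"]]
+ concatenated:    [u"greet", u"ings_", u"!_"]
+
+ which encode() then maps to the corresponding ids in the vocabulary.
+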
174204 """
175205
176206 def __init__(self, filename=None, num_reserved_ids=2):
@@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
182212 super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
183213
184214 def encode(self, raw_text):
185- """Converts a string to a list of subtoken ids.
215+ """Converts a native string to a list of subtoken ids.
186216
187217 Args:
188- raw_text: a string.
218+ raw_text: a native string.
189219 Returns:
190220 a list of integers in the range [0, vocab_size)
191221 """
192- return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
222+ return self._tokens_to_subtokens(self._tokenizer.encode(
223+ _native_to_unicode(raw_text)))
193224
194225 def decode(self, subtokens):
195- """Converts a sequence of subtoken ids to a string.
226+ """Converts a sequence of subtoken ids to a native string.
196227
197228 Args:
198229 subtokens: a list of integers in the range [0, vocab_size)
199230 Returns:
200- a string
231+ a native string
201232 """
202- return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
233+ return _unicode_to_native(self._tokenizer.decode(
234+ self._subtokens_to_tokens(subtokens)))
203235
204236 @property
205237 def vocab_size(self):
@@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
239271 if subtoken_string:
240272 return subtoken_string
241273 if 0 <= subtoken < self._num_reserved_ids:
242- return "%s_" % RESERVED_TOKENS[subtoken]
243- return "ID%d_" % subtoken
274+ return u"%s_" % RESERVED_TOKENS[subtoken]
275+ return u"ID%d_" % subtoken
244276
245277 def _escaped_token_to_subtokens(self, escaped_token):
246278 """Converts an escaped token string to a list of subtokens.
@@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
260292 if subtoken != -1:
261293 break
262294 end -= 1
263- if end > pos:
264- ret.append(subtoken)
265- pos = end
266- else:
267- # No subtoken in the vocabulary matches escaped_token[pos].
268- # This can happen if the token contains a Unicode character
269- # that did not occur in the vocabulary training set.
270- # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
271- # REPLACEMENT_CHARACTER.
272- ret.append(self.vocab_size - 1)
273- # Ensure that the outer loop continues
274- pos += 1
275- return ret
295+ assert end > pos
296+ ret.append(subtoken)
297+ pos = end
276298
277- @classmethod
278- def alphabet(cls, token_counts):
279- """Return the set of Unicode characters that appear in the tokens."""
280- alphabet_set = set()
281- for token in six.iterkeys(token_counts):
282- alphabet_set |= set(token)
283- return alphabet_set
299+ return ret
284300
285301 @classmethod
286302 def build_to_target_size(cls,
@@ -304,17 +320,12 @@ def build_to_target_size(cls,
304320 Returns:
305321 a SubwordTextEncoder instance.
306322 """
307- # Calculate the alphabet, i.e. the set of all Unicode characters
308- # that appear in the tokens.
309- alphabet_set = cls.alphabet(token_counts)
310- tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
311-
312323 def bisect(min_val, max_val):
313324 """Bisection to find the right size."""
314325 present_count = (max_val + min_val) // 2
315326 tf.logging.info("Trying min_count %d" % present_count)
316327 subtokenizer = cls()
317- subtokenizer.build_from_token_counts(token_counts, alphabet_set,
328+ subtokenizer.build_from_token_counts(token_counts,
318329 present_count, num_iterations)
319330 if min_val >= max_val or subtokenizer.vocab_size == target_size:
320331 return subtokenizer
@@ -333,17 +344,29 @@ def bisect(min_val, max_val):
333344
334345 def build_from_token_counts(self,
335346 token_counts,
336- alphabet_set,
337347 min_count,
338348 num_iterations=4):
339349 """Train a SubwordTextEncoder based on a dictionary of word counts.
340350
341351 Args:
342352 token_counts: a dictionary of Unicode strings to int.
343- alphabet_set: the set of Unicode characters that appear in the tokens.
344353 min_count: an integer - discard subtokens with lower counts.
345354 num_iterations: an integer. how many iterations of refinement.
346355 """
356+ # First, determine the alphabet: include every character whose total
357+ # count in the dataset is at least min_count.
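+ # For example (illustrative), with token_counts = {u"cat": 3, u"dog": 2}
+ # and min_count = 3, the data-derived alphabet would be {u"c", u"a", u"t"};
+ # the escape characters below are then added on top.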
358+ char_counts = defaultdict(int)
359+ for token, count in six.iteritems(token_counts):
360+ for c in token:
361+ char_counts[c] += count
362+ self._alphabet = set()
363+ for c, count in six.iteritems(char_counts):
364+ if count >= min_count:
365+ self._alphabet.add(c)
366+ # Make sure all characters needed for escaping are included
367+ for c in u"\\_;0123456789":
368+ self._alphabet.add(c)
369+
347370 # We build iteratively. On each iteration, we segment all the words,
348371 # then count the resulting potential subtokens, keeping the ones
349372 # with high enough counts for our new vocabulary.
@@ -367,43 +390,36 @@ def build_from_token_counts(self,
367390 for end in xrange(start + 1, len(escaped_token) + 1):
368391 subtoken_string = escaped_token[start:end]
369392 counts[subtoken_string] += count
393+ # Make sure every character in the alphabet survives the min_count cutoff
394+ for c in self._alphabet:
395+ counts[c] += min_count
370396 # Array of sets of candidate subtoken strings, by length
371397 len_to_subtoken_strings = []
372398 for subtoken_string, count in six.iteritems(counts):
373399 lsub = len(subtoken_string)
374- # All subtoken strings of length 1 are automatically included
375- # later, so we don't need to consider them here
376- if count < min_count or lsub <= 1:
377- continue
378- # Add this subtoken string to its length set
379- while len(len_to_subtoken_strings) <= lsub:
380- len_to_subtoken_strings.append(set())
381- len_to_subtoken_strings[lsub].add(subtoken_string)
400+ if count >= min_count:
401+ # Add this subtoken string to its length set
402+ while len(len_to_subtoken_strings) <= lsub:
403+ len_to_subtoken_strings.append(set())
404+ len_to_subtoken_strings[lsub].add(subtoken_string)
382405 new_subtoken_strings = []
383406 # consider the candidates longest to shortest, so that if we accept
384407 # a longer subtoken string, we can decrement the counts of its prefixes.
385- for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
408+ for lsub in reversed(range(1, len(len_to_subtoken_strings))):
409+ subtoken_strings = len_to_subtoken_strings[lsub]
386410 for subtoken_string in subtoken_strings:
387411 count = counts[subtoken_string]
388- if count < min_count:
389- continue
390- new_subtoken_strings.append((count, subtoken_string))
391- for l in xrange(1, len(subtoken_string)):
392- counts[subtoken_string[:l]] -= count
393- # Sort what we've got so far in decreasing order by count
412+ if count >= min_count:
413+ new_subtoken_strings.append((count, subtoken_string))
414+ for l in xrange(1, lsub):
415+ counts[subtoken_string[:l]] -= count
416+ # Sort in decreasing order by count
394417 new_subtoken_strings.sort(reverse=True)
395- # Add the alphabet set at the end of the vocabulary list
396- for char in alphabet_set:
397- new_subtoken_strings.append((0, char))
398- # Also include the Unicode REPLACEMENT CHARACTER to use
399- # when encountering previously unseen Unicode characters
400- # in the input (i.e. input external to the tokenizer training
401- # set, which may thus contain characters not in the alphabet_set).
402- # This must be the last entry in the subtoken vocabulary list.
403- new_subtoken_strings.append((0, u"\uFFFD"))
404418 # Now we have a candidate vocabulary
419+ old_alphabet = self._alphabet
405420 self._init_from_list([u""] * self._num_reserved_ids +
406421 [p[1] for p in new_subtoken_strings])
422+ assert old_alphabet == self._alphabet
407423 tf.logging.info("vocab_size = %d" % self.vocab_size)
408424
409425 original = "This sentence was encoded by the SubwordTextEncoder."
@@ -426,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
426442 self._all_subtoken_strings = subtoken_strings
427443 self._subtoken_string_to_id = {
428444 s: i for i, s in enumerate(subtoken_strings) if s}
445+ self._alphabet = set([c for c in subtoken_strings if len(c) == 1])
429446
430447 def _load_from_file(self, filename):
431448 """Load from a file."""
432449 subtoken_strings = []
433450 with tf.gfile.Open(filename) as f:
434451 for line in f:
435- if six.PY2:
436- subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
437- else:
438- subtoken_strings.append(line.strip()[1:-1])
452+ subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
439453 self._init_from_list(subtoken_strings)
440454
441455 def store_to_file(self, filename):
442456 with tf.gfile.Open(filename, "w") as f:
443457 for subtoken_string in self._all_subtoken_strings:
444- if six.PY2:
445- f.write("'" + subtoken_string.encode("utf-8") + "'\n")
446- else:
447- f.write("'" + subtoken_string + "'\n")
458+ f.write("'" + _unicode_to_native(subtoken_string) + "'\n")
448459
449460 def _escape_token(self, token):
450- r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
461+ r"""Escape away underscores and OOV characters and append '_'.
462+
463+ This allows the token to be expressed as the concatenation of a list
464+ of subtokens from the vocabulary. The underscore acts as a sentinel
465+ which allows us to invertibly concatenate multiple such lists.
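+
+ For example, assuming every character except u"~" occurs in the alphabet:
+
+ u"foo_bar~"  ->  u"foo\ubar\126;_"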
451466
452467 Args:
453- token: a string
468+ token: a unicode string
454469 Returns:
455- escaped_token: a string
470+ escaped_token: a unicode string
456471 """
457- return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
472+ token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
473+ ret = u""
474+ for c in token:
475+ if c in self._alphabet:
476+ ret += c
477+ else:
478+ ret += u"\\%d;" % ord(c)
479+ return ret
458480
459481 def _unescape_token(self, escaped_token):
460- r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
482+ r"""Inverse of _escape_token().
461483
462484 Args:
463- escaped_token: a string
485+ escaped_token: a unicode string
464486 Returns:
465- token: a string
487+ token: a unicode string
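+
+ For example (matching the _escape_token() example above):
+
+ u"foo\ubar\126;_"  ->  u"foo_bar~"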
466488 """
467- assert escaped_token[-1] == "_"
468- return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
489+ ret = u""
490+ escaped_token = escaped_token[:-1]
491+ pos = 0
492+ while pos < len(escaped_token):
493+ c = escaped_token[pos]
494+ if c == "\\":
495+ pos += 1
496+ c = escaped_token[pos]
497+ if c == u"u":
498+ ret += u"_"
499+ pos += 1
500+ elif c == "\\":
501+ ret += u"\\"  # an escaped backslash decodes back to a single backslash
502+ pos += 1
503+ else:
504+ semicolon_pos = escaped_token.find(u";", pos)
505+ if semicolon_pos == -1:
506+ continue
507+ try:
508+ ret += six.unichr(int(escaped_token[pos:semicolon_pos]))
509+ pos = semicolon_pos + 1
510+ except (ValueError, OverflowError) as _:
511+ pass
512+ else:
513+ ret += c
514+ pos += 1
515+ return ret
469516
470517 @classmethod
471518 def get_token_counts(cls, text_filepattern, corpus_max_lines):
@@ -477,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
477524 with tf.gfile.Open(text_filename) as f:
478525 for line in f:
479526 # The tokenizer updates token_counts in encode()
480- tok.encode(line.strip())
527+ tok.encode(_native_to_unicode( line.strip() ))
481528 lines_read += 1
482529 if corpus_max_lines > 0 and lines_read > corpus_max_lines:
483530 return tok.token_counts