 # Dependency imports

 import six
-from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import tokenizer

 import tensorflow as tf

+xrange = six.moves.xrange  # pylint: disable=redefined-builtin

 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
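The import shuffle above keeps `xrange` usable under both Python lines: `six.moves.xrange` resolves to the built-in `range` on Python 3 and to the lazy `xrange` on Python 2. A minimal check of the behavior being relied on:

```python
import six

xrange = six.moves.xrange  # pylint: disable=redefined-builtin

# Same lazy-sequence semantics on Python 2 and 3: nothing is materialized
# until you iterate.
assert list(xrange(3)) == [0, 1, 2]
```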
@@ -295,7 +295,7 @@ def encode(self, raw_text):
     Returns:
       a list of integers in the range [0, vocab_size)
     """
-    return self._tokens_to_subtokens(tokenizer.encode(
+    return self._tokens_to_subtoken_ids(tokenizer.encode(
         native_to_unicode(raw_text)))

   def decode(self, subtokens):
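The renames in these hunks only touch private helpers; the public round trip is unchanged. A usage sketch, assuming a `SubwordTextEncoder` loaded from an existing vocabulary file (the path is a placeholder) whose alphabet covers the sample text:

```python
from tensor2tensor.data_generators import text_encoder

# "vocab.subwords" stands in for a previously built vocabulary file.
encoder = text_encoder.SubwordTextEncoder("vocab.subwords")

ids = encoder.encode("Hello world")  # tokenize, escape, greedy subword match
assert encoder.decode(ids) == "Hello world"  # decode inverts encode
```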
@@ -307,14 +307,14 @@ def decode(self, subtokens):
       a native string
     """
     return unicode_to_native(tokenizer.decode(
-        self._subtokens_to_tokens(subtokens)))
+        self._subtoken_ids_to_tokens(subtokens)))

   @property
   def vocab_size(self):
     """The subtoken vocabulary size."""
     return len(self._all_subtoken_strings)

-  def _tokens_to_subtokens(self, tokens):
+  def _tokens_to_subtoken_ids(self, tokens):
     """Converts a list of tokens to a list of subtoken ids.

     Args:
@@ -324,11 +324,11 @@ def _tokens_to_subtokens(self, tokens):
     """
     ret = []
     for token in tokens:
-      ret.extend(self._escaped_token_to_subtokens(
+      ret.extend(self._escaped_token_to_subtoken_ids(
           _escape_token(token, self._alphabet)))
     return ret

-  def _subtokens_to_tokens(self, subtokens):
+  def _subtoken_ids_to_tokens(self, subtokens):
     """Converts a list of subtoken ids to a list of tokens.

     Args:
@@ -337,45 +337,58 @@ def _subtokens_to_tokens(self, subtokens):
       a list of strings.
     """
     concatenated = "".join(
-        [self._subtoken_to_subtoken_string(s) for s in subtokens])
+        [self._subtoken_id_to_subtoken_string(s) for s in subtokens])
     split = concatenated.split("_")
     return [_unescape_token(t + "_") for t in split if t]

-  def _subtoken_to_subtoken_string(self, subtoken):
-    """Subtoken_String (string) corresponding to the given subtoken (id)."""
+  def _subtoken_id_to_subtoken_string(self, subtoken):
+    """Converts a subtoken integer ID to a subtoken string."""
     if 0 <= subtoken < self.vocab_size:
       return self._all_subtoken_strings[subtoken]
     return u""

-  def _escaped_token_to_subtokens(self, escaped_token):
-    """Converts an escaped token string to a list of subtokens.
+  def _escaped_token_to_subtoken_strings(self, escaped_token):
+    """Converts an escaped token string to a list of subtoken strings.

     Args:
-      escaped_token: an escaped token
+      escaped_token: An escaped token as a unicode string.
     Returns:
-      a list of one or more integers.
+      A list of subtokens as unicode strings.
     """
+    # NOTE: This algorithm is greedy; it won't necessarily produce the "best"
+    # list of subtokens.
     ret = []
-    pos = 0
-    lesc = len(escaped_token)
-    while pos < lesc:
-      end = min(lesc, pos + self._max_subtoken_len)
-      while end > pos:
-        subtoken_id = self._subtoken_string_to_id.get(escaped_token[pos:end])
-        if subtoken_id is not None:
+    start = 0
+    token_len = len(escaped_token)
+    while start < token_len:
+      for end in xrange(
+          min(token_len, start + self._max_subtoken_len), start, -1):
+        subtoken = escaped_token[start:end]
+        if subtoken in self._subtoken_string_to_id:
+          ret.append(subtoken)
+          start = end
           break
-        end -= 1

-      # If there is no possible encoding of the escaped token then one of the
-      # characters in the token is not in the alphabet. This should be
-      # impossible and would be indicative of a bug.
-      assert subtoken_id is not None
-
-      ret.append(subtoken_id)
-      pos = end
+      else:  # Did not break
+        # If there is no possible encoding of the escaped token then one of the
+        # characters in the token is not in the alphabet. This should be
+        # impossible and would be indicative of a bug.
+        assert False, "Token substring not found in subtoken vocabulary."

     return ret

+  def _escaped_token_to_subtoken_ids(self, escaped_token):
+    """Converts an escaped token string to a list of subtoken IDs.
+
+    Args:
+      escaped_token: An escaped token as a unicode string.
+    Returns:
+      A list of subtoken IDs as integers.
+    """
+    return [
+        self._subtoken_string_to_id[subtoken]
+        for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)]
+
   @classmethod
   def build_to_target_size(cls,
                            target_size,
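The `for`/`else` above is the heart of the encoder: candidates are tried longest-first and the first hit wins, so segmentation is greedy rather than globally optimal, exactly as the new `NOTE` warns. A standalone sketch of the same loop (toy vocabulary; names are illustrative, not from the codebase):

```python
def greedy_subtokens(escaped_token, vocab, max_subtoken_len):
  """Greedy longest-match segmentation, mirroring the loop above."""
  ret = []
  start = 0
  while start < len(escaped_token):
    # Try the longest candidate first, then shrink until one is in vocab.
    for end in range(min(len(escaped_token), start + max_subtoken_len),
                     start, -1):
      if escaped_token[start:end] in vocab:
        ret.append(escaped_token[start:end])
        start = end
        break
    else:  # no candidate matched; a full alphabet makes this unreachable
      raise AssertionError("substring not in subtoken vocabulary")
  return ret

# Greedy is not always optimal: "ab" is consumed first even though
# ["a", "bcd"] would cover the string in fewer pieces.
print(greedy_subtokens("abcd", {"a", "b", "ab", "bcd", "c", "d"}, 3))
# -> ['ab', 'c', 'd']
```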
@@ -460,55 +473,51 @@ def build_from_token_counts(self,
       min_count = 1
     for i in xrange(num_iterations):
       tf.logging.info("Iteration {0}".format(i))
-      counts = collections.defaultdict(int)
+
+      # Collect all substrings of the encoded token that break along current
+      # subtoken boundaries.
+      subtoken_counts = collections.defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = _escape_token(token, self._alphabet)
-        # we will count all tails of the escaped_token, starting from boundaries
-        # determined by our current segmentation.
-        if i == 0:
-          starts = six.moves.range(len(escaped_token))
-        else:
-          subtokens = self._escaped_token_to_subtokens(escaped_token)
-          pos = 0
-          starts = []
-          for subtoken in subtokens:
-            starts.append(pos)
-            pos += len(self._all_subtoken_strings[subtoken])
-        for start in starts:
+        subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
+        start = 0
+        for subtoken in subtokens:
           for end in xrange(start + 1, len(escaped_token) + 1):
-            subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] += count
-      # Array of sets of candidate subtoken strings, by length
+            new_subtoken = escaped_token[start:end]
+            subtoken_counts[new_subtoken] += count
+          start += len(subtoken)
+
+      # Array of sets of candidate subtoken strings, by length.
       len_to_subtoken_strings = []
-      for subtoken_string, count in six.iteritems(counts):
+      for subtoken_string, count in six.iteritems(subtoken_counts):
         lsub = len(subtoken_string)
-        # Always include all the alphabet characters or some strings will
-        # be unencodeable.
-        if count >= min_count or subtoken_string in self._alphabet:
-          # Add this subtoken string to its length set
+        if count >= min_count:
           while len(len_to_subtoken_strings) <= lsub:
             len_to_subtoken_strings.append(set())
           len_to_subtoken_strings[lsub].add(subtoken_string)
-      new_subtoken_strings = []
+
       # Consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
+      new_subtoken_strings = []
       for lsub in xrange(len(len_to_subtoken_strings)-1, 0, -1):
         subtoken_strings = len_to_subtoken_strings[lsub]
         for subtoken_string in subtoken_strings:
-          count = counts[subtoken_string]
-          if count >= min_count or subtoken_string in self._alphabet:
-            # Exclude alphabet tokens here, as they must be included later
+          count = subtoken_counts[subtoken_string]
+          if count >= min_count:
+            # Exclude alphabet tokens here, as they must be included later,
             # explicitly, regardless of count.
             if subtoken_string not in self._alphabet:
              new_subtoken_strings.append((count, subtoken_string))
             for l in xrange(1, lsub):
-              counts[subtoken_string[:l]] -= count
+              subtoken_counts[subtoken_string[:l]] -= count
+
+      # Include the alphabet explicitly to guarantee all strings are encodable.
+      new_subtoken_strings.extend(
+          (subtoken_counts.get(a, 0), a) for a in self._alphabet)
       new_subtoken_strings.sort(reverse=True)

-      # Reinitialize to the candidate vocabulary, including the alphabet
-      # explicitly as the highest priority.
+      # Reinitialize to the candidate vocabulary.
       self._init_subtokens_from_list(
-          list(self._alphabet) +
           [subtoken for _, subtoken in new_subtoken_strings],
           reserved=num_reserved_ids)
       tf.logging.info("vocab_size = %d" % self.vocab_size)
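For orientation, the loop body above amounts to: re-segment every token with the current vocabulary, count all substrings that start on a segment boundary, accept candidates longest-to-shortest while debiting each accepted string's prefixes so they are not double-counted, then force the alphabet in. A compressed, self-contained sketch of one such iteration (toy stand-ins; `segment` plays the role of `_escaped_token_to_subtoken_strings`):

```python
import collections

def refine_vocab_once(token_counts, segment, alphabet, min_count):
  """One candidate-vocabulary rebuild, in miniature."""
  # Count every substring that starts on a current subtoken boundary.
  counts = collections.defaultdict(int)
  for token, count in token_counts.items():
    start = 0
    for subtoken in segment(token):
      for end in range(start + 1, len(token) + 1):
        counts[token[start:end]] += count
      start += len(subtoken)

  # Accept candidates longest to shortest; accepting one debits the counts
  # of all its proper prefixes.
  new_vocab = []
  for sub in sorted(counts, key=len, reverse=True):
    count = counts[sub]
    if count >= min_count and sub not in alphabet:
      new_vocab.append((count, sub))
      for l in range(1, len(sub)):
        counts[sub[:l]] -= count

  # The alphabet always goes in, so every string stays encodable.
  new_vocab.extend((counts.get(a, 0), a) for a in alphabet)
  return [sub for _, sub in sorted(new_vocab, reverse=True)]
```

Each iteration then re-segments with the vocabulary it just built, so the boundaries and counts sharpen over the handful of passes that `num_iterations` allows.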