@@ -175,9 +175,9 @@ class SubwordTextEncoder(TextEncoder):
   """
 
   def __init__(self, filename=None, num_reserved_ids=2):
-    """Read from a file."""
     self._tokenizer = tokenizer.Tokenizer()
     if filename is not None:
+      # Read from a file.
       self._load_from_file(filename)
 
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
@@ -235,14 +235,13 @@ def _subtokens_to_tokens(self, subtokens):
 
   def subtoken_to_subtoken_string(self, subtoken):
     """Subtoken_String (string) corresponding to the given subtoken (id)."""
-    if (subtoken >= 0 and subtoken < self.vocab_size and
-        self._all_subtoken_strings[subtoken]):
-      return self._all_subtoken_strings[subtoken]
-    else:
-      if 0 <= subtoken < self._num_reserved_ids:
-        return '%s_' % RESERVED_TOKENS[subtoken]
-      else:
-        return 'ID%d_' % subtoken
+    if 0 <= subtoken < self.vocab_size:
+      subtoken_string = self._all_subtoken_strings[subtoken]
+      if subtoken_string:
+        return subtoken_string
+    if 0 <= subtoken < self._num_reserved_ids:
+      return '%s_' % RESERVED_TOKENS[subtoken]
+    return 'ID%d_' % subtoken
 
   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.
@@ -262,21 +261,32 @@ def _escaped_token_to_subtokens(self, escaped_token):
         if subtoken != -1:
           break
         end -= 1
-      ret.append(subtoken)
       if end > pos:
+        ret.append(subtoken)
         pos = end
       else:
-        # This kinda should not happen, but it does. Cop out by skipping the
-        # nonexistent subtoken from the returned list.
-        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        # No subtoken in the vocabulary matches escaped_token[pos].
+        # This can happen if the token contains a Unicode character
+        # that did not occur in the vocabulary training set.
+        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
+        # REPLACEMENT_CHARACTER.
+        ret.append(self.vocab_size - 1)
+        # Ensure that the outer loop continues
        pos += 1
     return ret
 
+  @classmethod
+  def alphabet(cls, token_counts):
+    """Return the set of Unicode characters that appear in the tokens."""
+    alphabet_set = set()
+    for token in six.iterkeys(token_counts):
+      alphabet_set |= set(token)
+    return alphabet_set
+
   @classmethod
   def build_to_target_size(cls,
                            target_size,
                            token_counts,
-                           store_filename,
                            min_val,
                            max_val,
                            num_iterations=4):
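The else-branch above changes the failure mode of _escaped_token_to_subtokens: instead of silently dropping an unmatchable character, it emits the last vocabulary id, which the new build code reserves for u'\ufffd'. A minimal standalone sketch of the greedy longest-match loop with that fallback, plus the new alphabet() helper, using a toy vocabulary dict rather than the real class:

# -*- coding: utf-8 -*-
vocab = {u'he': 0, u'l': 1, u'lo_': 2, u'\ufffd': 3}   # toy vocabulary; last id is the fallback
vocab_size = len(vocab)

def escaped_token_to_subtokens(escaped_token):
  ret = []
  pos = 0
  while pos < len(escaped_token):
    end = len(escaped_token)
    while end > pos:
      if escaped_token[pos:end] in vocab:       # longest match wins
        break
      end -= 1
    if end > pos:
      ret.append(vocab[escaped_token[pos:end]])
      pos = end
    else:
      ret.append(vocab_size - 1)                # unseen character: emit the replacement id
      pos += 1
  return ret

def alphabet(token_counts):
  chars = set()
  for token in token_counts:
    chars |= set(token)
  return chars

assert escaped_token_to_subtokens(u'hello_') == [0, 1, 2]
assert escaped_token_to_subtokens(u'hexlo_') == [0, 3, 2]   # 'x' is not in the toy vocab
assert alphabet({u'ab': 3, u'bc': 1}) == {u'a', u'b', u'c'}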
@@ -296,50 +306,51 @@ def build_to_target_size(cls,
     Returns:
       a SubwordTextEncoder instance.
     """
-    present_count = (max_val + min_val) // 2
-    tf.logging.info('Trying min_count %d' % present_count)
-    subtokenizer = cls()
-    subtokenizer.build_from_token_counts(token_counts, store_filename,
-                                         present_count, num_iterations)
-
-    if min_val >= max_val or subtokenizer.vocab_size == target_size:
-      return subtokenizer
-    elif subtokenizer.vocab_size > target_size:
-      other_subtokenizer = cls.build_to_target_size(
-          target_size, token_counts, store_filename, present_count + 1, max_val,
-          num_iterations)
-      if (abs(other_subtokenizer.vocab_size - target_size) <
-          abs(subtokenizer.vocab_size - target_size)):
-        return other_subtokenizer
-      else:
+
+    # Calculate the alphabet, i.e. the set of all Unicode characters
+    # that appear in the tokens
+    alphabet_set = cls.alphabet(token_counts)
+    tf.logging.info('Alphabet contains %d characters' % len(alphabet_set))
+
+    def bisect(min_val, max_val):
+      present_count = (max_val + min_val) // 2
+      tf.logging.info('Trying min_count %d' % present_count)
+      subtokenizer = cls()
+      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+                                           present_count, num_iterations)
+
+      if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
-    else:
-      other_subtokenizer = cls.build_to_target_size(
-          target_size, token_counts, store_filename, min_val, present_count - 1,
-          num_iterations)
+      if subtokenizer.vocab_size > target_size:
+        other_subtokenizer = bisect(present_count + 1, max_val)
+      else:
+        other_subtokenizer = bisect(min_val, present_count - 1)
       if (abs(other_subtokenizer.vocab_size - target_size) <
           abs(subtokenizer.vocab_size - target_size)):
         return other_subtokenizer
       else:
         return subtokenizer
 
+    return bisect(min_val, max_val)
+
   def build_from_token_counts(self,
                               token_counts,
-                              store_filename,
+                              alphabet_set,
                               min_count,
                               num_iterations=4):
     """Train a SubwordTextEncoder based on a dictionary of word counts.
 
     Args:
-      token_counts: a dictionary of string to int.
-      store_filename: a string - where to write the vocabulary.
+      token_counts: a dictionary of Unicode strings to int.
+      alphabet_set: the set of Unicode characters that appear in the tokens.
       min_count: an integer - discard subtokens with lower counts.
       num_iterations: an integer. how many iterations of refinement.
     """
     # We build iteratively. On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
+      tf.logging.info("Iteration {0}".format(i))
       counts = defaultdict(int)
       for token, count in six.iteritems(token_counts):
        escaped_token = self._escape_token(token)
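For the bisect() closure introduced in build_to_target_size above: vocabulary size shrinks as min_count grows, so a binary search over min_count homes in on the size closest to target_size. A generic sketch of the same pattern on a plain monotone function (build_vocab here is a hypothetical stand-in for building a real subtokenizer):

def search_min_count(target_size, build_vocab, min_val, max_val):
  """Return (min_count, size) with size as close as possible to target_size."""
  present_count = (max_val + min_val) // 2
  size = build_vocab(present_count)            # larger min_count => smaller vocab
  if min_val >= max_val or size == target_size:
    return present_count, size
  if size > target_size:
    other = search_min_count(target_size, build_vocab, present_count + 1, max_val)
  else:
    other = search_min_count(target_size, build_vocab, min_val, present_count - 1)
  # keep whichever candidate lands nearer the target
  return min(other, (present_count, size), key=lambda p: abs(p[1] - target_size))

# e.g. with a fake vocabulary-size function:
assert search_min_count(100, lambda c: 1000 // c, 1, 1000) == (10, 100)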
@@ -353,39 +364,49 @@ def build_from_token_counts(self,
         starts = []
         for subtoken in subtokens:
           starts.append(pos)
-          pos += len(self.subtoken_to_subtoken_string(subtoken))
+          pos += len(self._all_subtoken_strings[subtoken])
         for start in starts:
-          for end in xrange(start + 1, len(escaped_token)):
+          for end in xrange(start + 1, len(escaped_token) + 1):
            subtoken_string = escaped_token[start:end]
            counts[subtoken_string] += count
-      # array of lists of candidate subtoken strings, by length
+      # Array of sets of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
         lsub = len(subtoken_string)
-        # all subtoken strings of length 1 are included regardless of count
-        if count < min_count and lsub != 1:
+        # All subtoken strings of length 1 are automatically included
+        # later, so we don't need to consider them here
+        if count < min_count or lsub <= 1:
          continue
+        # Add this subtoken string to its length set
        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[lsub].append(subtoken_string)
+          len_to_subtoken_strings.append(set())
+        len_to_subtoken_strings[lsub].add(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in len_to_subtoken_strings[::-1]:
+      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
        for subtoken_string in subtoken_strings:
          count = counts[subtoken_string]
-          if count < min_count and len(subtoken_string) != 1:
-            # subtoken strings of length 1 are included regardless of count
+          if count < min_count:
            continue
-          new_subtoken_strings.append((-count, subtoken_string))
+          new_subtoken_strings.append((count, subtoken_string))
          for l in xrange(1, len(subtoken_string)):
            counts[subtoken_string[:l]] -= count
-      # Make sure to include the underscore as a subtoken string
-      new_subtoken_strings.append((0, '_'))
-      new_subtoken_strings.sort()
-      self._init_from_list([''] * self._num_reserved_ids +
+      # Sort what we've got so far in decreasing order by count
+      new_subtoken_strings.sort(reverse=True)
+      # Add the alphabet set at the end of the vocabulary list
+      for char in alphabet_set:
+        new_subtoken_strings.append((0, char))
+      # Also include the Unicode REPLACEMENT CHARACTER to use
+      # when encountering previously unseen Unicode characters
+      # in the input (i.e. input external to the tokenizer training
+      # set, which may thus contain characters not in the alphabet_set).
+      # This must be the last entry in the subtoken vocabulary list.
+      new_subtoken_strings.append((0, u'\uFFFD'))
+      # Now we have a candidate vocabulary
+      self._init_from_list([u''] * self._num_reserved_ids +
                           [p[1] for p in new_subtoken_strings])
-      print('vocab_size = %d' % self.vocab_size)
+      tf.logging.info('vocab_size = %d' % self.vocab_size)
 
       original = 'This sentence was encoded by the SubwordTextEncoder.'
       encoded = self.encode(original)
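A toy standalone sketch of the candidate-counting step above: count substrings, bucket candidates of length two or more by length, then walk them longest to shortest, keeping those at or above min_count and subtracting their count from every proper prefix. Numbers and tokens are made up, and unlike the real code the toy treats every character position as a segmentation boundary:

from collections import defaultdict

token_counts = {'aba_': 5, 'ab_': 3}
min_count = 4

counts = defaultdict(int)
for token, count in token_counts.items():
  for start in range(len(token)):              # here: every position is a boundary
    for end in range(start + 1, len(token) + 1):
      counts[token[start:end]] += count

len_to_strings = defaultdict(set)              # candidates of length >= 2, by length
for s, c in counts.items():
  if c >= min_count and len(s) > 1:
    len_to_strings[len(s)].add(s)

new_subtoken_strings = []
for l in sorted(len_to_strings, reverse=True):
  for s in len_to_strings[l]:
    c = counts[s]
    if c < min_count:
      continue
    new_subtoken_strings.append((c, s))
    for i in range(1, len(s)):
      counts[s[:i]] -= c                       # a longer pick absorbs its prefixes' credit

new_subtoken_strings.sort(reverse=True)
print(new_subtoken_strings)                    # [(5, 'ba_'), (5, 'aba_'), (5, 'a_')]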
@@ -394,33 +415,33 @@ def build_from_token_counts(self,
       decoded = self.decode(encoded)
       print(decoded)
       assert decoded == original
-      self._store_to_file(store_filename)
+
+  def dump(self):
+    """Debugging dump of the current subtoken vocabulary."""
+    subtoken_strings = [(i, s) for s, i in six.iteritems(self._subtoken_string_to_id)]
+    print(u", ".join(u"{0} : '{1}'".format(i, s) for i, s in sorted(subtoken_strings)))
 
   def _init_from_list(self, subtoken_strings):
     """Initialize from a list of subtoken strings."""
     self._all_subtoken_strings = subtoken_strings
-    self._subtoken_string_to_id = {}
-    for i in xrange(len(subtoken_strings)):
-      subtoken_string = subtoken_strings[i]
-      if subtoken_string:
-        self._subtoken_string_to_id[subtoken_string] = i
+    self._subtoken_string_to_id = {s: i for i, s in enumerate(subtoken_strings) if s}
 
   def _load_from_file(self, filename):
     """Load from a file."""
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
         if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+          subtoken_strings.append(line.strip()[1:-1].decode('utf-8'))
        else:
           subtoken_strings.append(line.strip()[1:-1])
     self._init_from_list(subtoken_strings)
 
-  def _store_to_file(self, filename):
+  def store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
         if six.PY2:
-          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+          f.write('\'' + subtoken_string.encode('utf-8') + '\'\n')
        else:
           f.write('\'' + subtoken_string + '\'\n')
 
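A usage sketch of the now-public store_to_file together with the file-reading constructor. The path and token counts are made up; this assumes the class and its methods are importable as shown in this diff:

token_counts = {u'hello': 7, u'world': 4}
alphabet_set = SubwordTextEncoder.alphabet(token_counts)

encoder = SubwordTextEncoder()
encoder.build_from_token_counts(token_counts, alphabet_set, min_count=2)
encoder.store_to_file('/tmp/my_subwords.vocab')      # hypothetical path

# later: reload the stored vocabulary through the constructor
reloaded = SubwordTextEncoder('/tmp/my_subwords.vocab')
print(reloaded.decode(reloaded.encode(u'hello world')))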
@@ -437,43 +458,26 @@ def _escape_token(self, token):
   def _unescape_token(self, escaped_token):
     r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
 
-    TODO(noam): There must be some better way to do this with regexps.
-
     Args:
       escaped_token: a string
     Returns:
       token: a string
     """
     assert escaped_token[-1] == '_'
-    escaped_token = escaped_token[:-1]
-    if '\\' not in escaped_token:
-      return escaped_token
-    ret = ''
-    pos = 0
-    while pos < len(escaped_token):
-      if escaped_token[pos] == '\\' and pos + 1 < len(escaped_token):
-        if escaped_token[pos + 1] == 'u':
-          ret += '_'
-        else:
-          ret += escaped_token[pos + 1]
-        pos += 1
-      pos += 1
-    return ret
+    return escaped_token[:-1].replace('\\u', '_').replace('\\\\', '\\')
 
   @classmethod
   def get_token_counts(cls, text_filepattern, corpus_max_lines):
-    """Read the corpus and compute a dictionary of word counts."""
+    """Read the corpus and compute a dictionary of token counts."""
     tok = tokenizer.Tokenizer()
-    token_counts = {}
     lines_read = 0
     filenames = tf.gfile.Glob(text_filepattern)
     for text_filename in filenames:
       with tf.gfile.Open(text_filename) as f:
         for line in f:
-          tokens = tok.encode(line.strip())
-          for t in tokens:
-            token_counts[t] = token_counts.get(t, 0) + 1
+          # The tokenizer updates token_counts in encode()
+          tok.encode(line.strip())
          lines_read += 1
          if corpus_max_lines > 0 and lines_read > corpus_max_lines:
-            return token_counts
-    return token_counts
+            return tok.token_counts
+    return tok.token_counts
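The one-line _unescape_token above inverts the escaping described in its docstring. A standalone sketch of that round trip (escape_token here mirrors what the unchanged _escape_token is documented to do; it is not the module's code):

def escape_token(token):
  # protect '\' and '_' first, then mark the end of the token with '_'
  return token.replace('\\', '\\\\').replace('_', '\\u') + '_'

def unescape_token(escaped_token):
  assert escaped_token[-1] == '_'
  # note: the replace chain assumes the pair backslash-u never arises from
  # an escaped backslash followed by a literal 'u'
  return escaped_token[:-1].replace('\\u', '_').replace('\\\\', '\\')

for original in ['plain', 'with_underscore', 'back\\slash']:
  assert unescape_token(escape_token(original)) == original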