@@ -23,6 +23,8 @@
 from __future__ import division
 from __future__ import print_function
 
+from collections import defaultdict
+
 # Dependency imports
 
 import six
@@ -35,6 +37,12 @@
 PAD = '<pad>'
 EOS = '<EOS>'
 RESERVED_TOKENS = [PAD, EOS]
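+# Byte versions of the reserved tokens, used by ByteTextEncoder.decode.
+# In Python 2 a str is already a byte string, so no conversion is needed.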
+if six.PY2:
+  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
 
 
 class TextEncoder(object):
@@ -87,17 +93,27 @@ class ByteTextEncoder(TextEncoder):
   """Encodes each byte to an id. For 8-bit strings only."""
 
   def encode(self, s):
-    return [ord(c) + self._num_reserved_ids for c in s]
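+    # Each id is the byte value shifted up by the number of reserved ids.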
+    numres = self._num_reserved_ids
+    if six.PY2:
+      return [ord(c) + numres for c in s]
+    # Python 3: explicitly convert to UTF-8 bytes first.
+    return [c + numres for c in s.encode('utf-8')]
 
   def decode(self, ids):
+    numres = self._num_reserved_ids
     decoded_ids = []
+    int2byte = six.int2byte
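+    # six.int2byte converts an integer in [0, 255] to a single byte.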
     for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(RESERVED_TOKENS[int(id_)])
+      if 0 <= id_ < numres:
+        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
       else:
-        decoded_ids.append(chr(id_))
-
-    return ''.join(decoded_ids)
+        decoded_ids.append(int2byte(id_ - numres))
+    if six.PY2:
+      return ''.join(decoded_ids)
+    # Python 3: join the byte strings, then decode as UTF-8.
+    return b''.join(decoded_ids).decode('utf-8')
 
   @property
   def vocab_size(self):
@@ -111,20 +125,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
111125 """Initialize from a file, one token per line."""
112126 super (TokenTextEncoder , self ).__init__ (num_reserved_ids = num_reserved_ids )
113127 self ._reverse = reverse
114- if vocab_filename is not None :
115- self ._load_vocab_from_file (vocab_filename )
128+ self ._load_vocab_from_file (vocab_filename )
116129
117130 def encode (self , sentence ):
118131 """Converts a space-separated string of tokens to a list of ids."""
119132 ret = [self ._token_to_id [tok ] for tok in sentence .strip ().split ()]
120- if self ._reverse :
121- ret = ret [::- 1 ]
122- return ret
133+ return ret [::- 1 ] if self ._reverse else ret
123134
124135 def decode (self , ids ):
125- if self ._reverse :
126- ids = ids [::- 1 ]
127- return ' ' .join ([self ._safe_id_to_token (i ) for i in ids ])
136+ seq = reversed (ids ) if self ._reverse else ids
137+ return ' ' .join ([self ._safe_id_to_token (i ) for i in seq ])
128138
129139 @property
130140 def vocab_size (self ):
@@ -243,15 +253,23 @@ def _escaped_token_to_subtokens(self, escaped_token):
     """
     ret = []
     pos = 0
-    while pos < len(escaped_token):
-      end = len(escaped_token)
-      while True:
+    lesc = len(escaped_token)
+    while pos < lesc:
+      end = lesc
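+      # Greedily match the longest subtoken string in the vocabulary.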
+      while end > pos:
         subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
         if subtoken != -1:
           break
         end -= 1
-      ret.append(subtoken)
-      pos = end
+      if end > pos:
+        ret.append(subtoken)
+        pos = end
+      else:
+        # This should not happen, but it does. Cop out by skipping the
+        # unmatched character rather than emitting a nonexistent subtoken id.
+        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
+        pos += 1
     return ret
 
   @classmethod
@@ -322,13 +339,13 @@ def build_from_token_counts(self,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
-      counts = {}
+      counts = defaultdict(int)
       for token, count in six.iteritems(token_counts):
         escaped_token = self._escape_token(token)
         # we will count all tails of the escaped_token, starting from boundaries
         # determined by our current segmentation.
         if i == 0:
-          starts = list(range(len(escaped_token)))
+          starts = xrange(len(escaped_token))
         else:
           subtokens = self._escaped_token_to_subtokens(escaped_token)
           pos = 0
@@ -337,31 +354,35 @@
            starts.append(pos)
            pos += len(self.subtoken_to_subtoken_string(subtoken))
        for start in starts:
-        for end in xrange(start + 1, len(escaped_token) + 1):
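+          # end stops before len(escaped_token), so no candidate ever includes
+          # the trailing '_' from _escape_token; '_' is appended explicitly below.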
+        for end in xrange(start + 1, len(escaped_token)):
           subtoken_string = escaped_token[start:end]
-          counts[subtoken_string] = counts.get(subtoken_string, 0) + count
+          counts[subtoken_string] += count
       # array of lists of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
-        if count < min_count or len(subtoken_string) <= 1:
+        lsub = len(subtoken_string)
+        # all subtoken strings of length 1 are included regardless of count
+        if count < min_count and lsub != 1:
           continue
-        while len(len_to_subtoken_strings) <= len(subtoken_string):
+        while len(len_to_subtoken_strings) <= lsub:
           len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
+        len_to_subtoken_strings[lsub].append(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
       for subtoken_strings in len_to_subtoken_strings[::-1]:
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
+          if count < min_count and len(subtoken_string) != 1:
+            # subtoken strings of length 1 are included regardless of count
             continue
           new_subtoken_strings.append((-count, subtoken_string))
           for l in xrange(1, len(subtoken_string)):
             counts[subtoken_string[:l]] -= count
-      # make sure we have all single characters.
-      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
-                                   for i in xrange(2**8)])
+      # Make sure to include the underscore as a subtoken string
+      new_subtoken_strings.append((0, '_'))
       new_subtoken_strings.sort()
       self._init_from_list([''] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
@@ -390,13 +409,20 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
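+        # The 'string-escape' codec only exists in Python 2.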
+        if six.PY2:
+          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
+        else:
+          subtoken_strings.append(line.strip()[1:-1])
     self._init_from_list(subtoken_strings)
 
   def _store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
-        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        if six.PY2:
+          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
+        else:
+          f.write('\'' + subtoken_string + '\'\n')
 
   def _escape_token(self, token):
     r"""Translate '\'->'\\' and '_'->'\u', then append '_'.