
import six
from six.moves import xrange  # pylint: disable=redefined-builtin
-from collections import defaultdict
from tensor2tensor.data_generators import tokenizer

import tensorflow as tf

PAD = '<pad>'
EOS = '<EOS>'
RESERVED_TOKENS = [PAD, EOS]
-if six.PY2:
-  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
-else:
-  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
+


class TextEncoder(object):
  """Base class for converting from ints to/from human readable strings."""
@@ -91,25 +87,17 @@ class ByteTextEncoder(TextEncoder):
  """Encodes each byte to an id. For 8-bit strings only."""

  def encode(self, s):
-    numres = self._num_reserved_ids
-    if six.PY2:
-      return [ord(c) + numres for c in s]
-    # Python3: explicitly convert to UTF-8
-    return [c + numres for c in s.encode("utf-8")]
+    return [ord(c) + self._num_reserved_ids for c in s]

  def decode(self, ids):
-    numres = self._num_reserved_ids
    decoded_ids = []
-    int2byte = six.int2byte
    for id_ in ids:
-      if 0 <= id_ < numres:
-        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
+      if 0 <= id_ < self._num_reserved_ids:
+        decoded_ids.append(RESERVED_TOKENS[int(id_)])
      else:
-        decoded_ids.append(int2byte(id_ - numres))
-    if six.PY2:
-      return ''.join(decoded_ids)
-    # Python3: join byte arrays and then decode string
-    return b''.join(decoded_ids).decode("utf-8")
+        decoded_ids.append(chr(id_))
+
+    return ''.join(decoded_ids)

  @property
  def vocab_size(self):
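
As context for the byte-level scheme above: raw byte values are shifted past the reserved ids so that 0 and 1 stay free for PAD and EOS. A minimal standalone sketch (names here are illustrative, not the module's API):

    NUM_RESERVED = 2  # id 0 = '<pad>', id 1 = '<EOS>'

    def byte_encode(s):
      # Offset each byte value so the reserved ids stay unused.
      return [ord(c) + NUM_RESERVED for c in s]

    print(byte_encode('hi'))  # [106, 107]
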
@@ -123,16 +111,20 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
    """Initialize from a file, one token per line."""
    super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
    self._reverse = reverse
-    self._load_vocab_from_file(vocab_filename)
+    if vocab_filename is not None:
+      self._load_vocab_from_file(vocab_filename)

  def encode(self, sentence):
    """Converts a space-separated string of tokens to a list of ids."""
    ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
-    return ret[::-1] if self._reverse else ret
+    if self._reverse:
+      ret = ret[::-1]
+    return ret

  def decode(self, ids):
-    seq = reversed(ids) if self._reverse else ids
-    return ' '.join([self._safe_id_to_token(i) for i in seq])
+    if self._reverse:
+      ids = ids[::-1]
+    return ' '.join([self._safe_id_to_token(i) for i in ids])

  @property
  def vocab_size(self):
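
For reference, the reverse flag applies the same reversal on both sides, so encode/decode still round-trips; a toy sketch with a hypothetical two-word vocabulary:

    vocab = {'hello': 2, 'world': 3}
    inv = {v: k for k, v in vocab.items()}

    def encode(sentence, reverse=False):
      ids = [vocab[tok] for tok in sentence.strip().split()]
      return ids[::-1] if reverse else ids

    def decode(ids, reverse=False):
      if reverse:
        ids = ids[::-1]
      return ' '.join(inv[i] for i in ids)

    assert decode(encode('hello world', reverse=True), reverse=True) == 'hello world'
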
@@ -251,22 +243,15 @@ def _escaped_token_to_subtokens(self, escaped_token):
    """
    ret = []
    pos = 0
-    lesc = len(escaped_token)
-    while pos < lesc:
-      end = lesc
-      while end > pos:
+    while pos < len(escaped_token):
+      end = len(escaped_token)
+      while True:
        subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
        if subtoken != -1:
          break
        end -= 1
      ret.append(subtoken)
-      if end > pos:
-        pos = end
-      else:
-        # This kinda should not happen, but it does. Cop out by skipping the
-        # nonexistent subtoken from the returned list.
-        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
-        pos += 1
+      pos = end
    return ret

  @classmethod
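
The loop above is a greedy longest-match-first segmentation: take the longest known prefix of the remainder, emit it, and continue from there. A standalone sketch, under the assumption (which the builder below enforces by seeding all single characters) that every character appears in the vocabulary; the toy vocabulary is made up:

    SUBTOKENS = {'hell', 'o_', 'h', 'e', 'l', 'o', '_'}

    def segment(escaped_token):
      ret, pos = [], 0
      while pos < len(escaped_token):
        # Shrink the candidate until it matches a known subtoken string.
        end = len(escaped_token)
        while end > pos and escaped_token[pos:end] not in SUBTOKENS:
          end -= 1
        assert end > pos, 'character %r missing from vocab' % escaped_token[pos]
        ret.append(escaped_token[pos:end])
        pos = end
      return ret

    print(segment('hello_'))  # ['hell', 'o_']
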
@@ -337,13 +322,13 @@ def build_from_token_counts(self,
    # then count the resulting potential subtokens, keeping the ones
    # with high enough counts for our new vocabulary.
    for i in xrange(num_iterations):
-      counts = defaultdict(int)
+      counts = {}
      for token, count in six.iteritems(token_counts):
        escaped_token = self._escape_token(token)
        # we will count all tails of the escaped_token, starting from boundaries
        # determined by our current segmentation.
        if i == 0:
-          starts = xrange(len(escaped_token))
+          starts = list(range(len(escaped_token)))
        else:
          subtokens = self._escaped_token_to_subtokens(escaped_token)
          pos = 0
@@ -352,33 +337,31 @@
            starts.append(pos)
            pos += len(self.subtoken_to_subtoken_string(subtoken))
        for start in starts:
-          for end in xrange(start + 1, len(escaped_token)):
+          for end in xrange(start + 1, len(escaped_token) + 1):
            subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] += count
+            counts[subtoken_string] = counts.get(subtoken_string, 0) + count
      # array of lists of candidate subtoken strings, by length
      len_to_subtoken_strings = []
      for subtoken_string, count in six.iteritems(counts):
-        lsub = len(subtoken_string)
-        # all subtoken strings of length 1 are included regardless of count
-        if count < min_count and lsub != 1:
+        if count < min_count or len(subtoken_string) <= 1:
          continue
-        while len(len_to_subtoken_strings) <= lsub:
+        while len(len_to_subtoken_strings) <= len(subtoken_string):
          len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[lsub].append(subtoken_string)
+        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
      new_subtoken_strings = []
      # consider the candidates longest to shortest, so that if we accept
      # a longer subtoken string, we can decrement the counts of its prefixes.
      for subtoken_strings in len_to_subtoken_strings[::-1]:
        for subtoken_string in subtoken_strings:
          count = counts[subtoken_string]
-          if count < min_count and len(subtoken_string) != 1:
-            # subtoken strings of length 1 are included regardless of count
+          if count < min_count:
            continue
          new_subtoken_strings.append((-count, subtoken_string))
          for l in xrange(1, len(subtoken_string)):
            counts[subtoken_string[:l]] -= count
-      # Make sure to include the underscore as a subtoken string
-      new_subtoken_strings.append((0, '_'))
+      # make sure we have all single characters.
+      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
+                                   for i in xrange(2**8)])
      new_subtoken_strings.sort()
      self._init_from_list([''] * self._num_reserved_ids +
                           [p[1] for p in new_subtoken_strings])
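
To see why the candidates are visited longest to shortest: every occurrence of an accepted string also counted once toward each of its proper prefixes, so those prefix counts are discounted and a prefix only survives on independent evidence. A toy trace with made-up counts and min_count = 3:

    min_count = 3
    counts = {'the': 10, 'th': 12, 't': 15}
    accepted = []
    for s in sorted(counts, key=len, reverse=True):
      c = counts[s]
      if c < min_count:
        continue
      accepted.append((-c, s))
      for l in range(1, len(s)):
        counts[s[:l]] -= c   # discount occurrences explained by s

    print(accepted)  # [(-10, 'the'), (-5, 't')] -- 'th' dropped to 2 and fell out
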
@@ -407,19 +390,13 @@ def _load_from_file(self, filename):
    subtoken_strings = []
    with tf.gfile.Open(filename) as f:
      for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
    self._init_from_list(subtoken_strings)

  def _store_to_file(self, filename):
    with tf.gfile.Open(filename, 'w') as f:
      for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
-        else:
-          f.write('\'' + subtoken_string + '\'\n')
+        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')

  def _escape_token(self, token):
    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
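
The escaping named in the docstring above protects '_' so it can serve as the end-of-token marker. A round-trip sketch, assuming tokens contain no literal backslashes (helper names are illustrative):

    def escape_token(token):
      # '\' -> '\\' and '_' -> '\u', then a trailing '_' marks the boundary.
      return token.replace('\\', '\\\\').replace('_', '\\u') + '_'

    def unescape_token(escaped):
      # Drop the trailing '_' and undo the substitutions (exact inverse
      # only for tokens without literal backslashes).
      assert escaped.endswith('_')
      return escaped[:-1].replace('\\u', '_').replace('\\\\', '\\')

    assert unescape_token(escape_token('foo_bar')) == 'foo_bar'
    print(escape_token('foo_bar'))  # foo\ubar_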