This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b09291d

Ryan Sepassi committed
Bump version to 1.0.6
PiperOrigin-RevId: 159970178
1 parent 3410bea commit b09291d

4 files changed: +50 −66 lines changed


tensor2tensor/data_generators/generator_utils.py

Lines changed: 1 addition & 1 deletion
@@ -22,12 +22,12 @@
 import io
 import os
 import tarfile
+import urllib

 # Dependency imports

 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
-import six.moves.urllib_request

 from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
 from tensor2tensor.data_generators.tokenizer import Tokenizer
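This hunk reverts the download import from the Python 3 compatible six.moves.urllib_request back to Python 2's urllib module. As a rough sketch of what that import is typically used for in a download helper (the URL, destination path, and callback below are illustrative assumptions, not taken from this file):

import urllib


def _report_hook(block_num, block_size, total_size):
  # Hypothetical progress callback: urlretrieve calls it after each block.
  if total_size > 0:
    percent = 100.0 * block_num * block_size / total_size
    print("downloaded %.1f%%" % min(percent, 100.0))


# Hypothetical URL and destination, for illustration only (Python 2).
urllib.urlretrieve("http://example.com/corpus.tar.gz",
                   "/tmp/corpus.tar.gz", _report_hook)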

tensor2tensor/data_generators/image.py

Lines changed: 1 addition & 2 deletions
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function

+import cPickle
 import gzip
 import io
 import json
@@ -31,8 +32,6 @@
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from six.moves import zip  # pylint: disable=redefined-builtin
-from six.moves import cPickle
-
 from tensor2tensor.data_generators import generator_utils

 import tensorflow as tf
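Here the pickle import moves back to Python 2's cPickle rather than six.moves.cPickle. A minimal Python 2 sketch of reading a pickled CIFAR-style batch with it; the path is hypothetical and the key names follow the standard CIFAR-10 python batches:

import cPickle

# Hypothetical local path; CIFAR-10 python batches are plain pickled dicts.
with open("/tmp/cifar-10-batches-py/data_batch_1", "rb") as f:
  batch = cPickle.load(f)    # Python 2 pickle deserialization
images = batch["data"]       # uint8 image rows under the "data" key
labels = batch["labels"]     # integer class ids under the "labels" key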

tensor2tensor/data_generators/text_encoder.py

File mode changed from 100755 to 100644
Lines changed: 32 additions & 55 deletions
@@ -27,7 +27,6 @@

 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
-from collections import defaultdict
 from tensor2tensor.data_generators import tokenizer

 import tensorflow as tf
@@ -36,10 +35,7 @@
 PAD = '<pad>'
 EOS = '<EOS>'
 RESERVED_TOKENS = [PAD, EOS]
-if six.PY2:
-  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
-else:
-  RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]
+

 class TextEncoder(object):
   """Base class for converting from ints to/from human readable strings."""
@@ -91,25 +87,17 @@ class ByteTextEncoder(TextEncoder):
   """Encodes each byte to an id. For 8-bit strings only."""

   def encode(self, s):
-    numres = self._num_reserved_ids
-    if six.PY2:
-      return [ord(c) + numres for c in s]
-    # Python3: explicitly convert to UTF-8
-    return [c + numres for c in s.encode("utf-8")]
+    return [ord(c) + self._num_reserved_ids for c in s]

   def decode(self, ids):
-    numres = self._num_reserved_ids
     decoded_ids = []
-    int2byte = six.int2byte
     for id_ in ids:
-      if 0 <= id_ < numres:
-        decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
+      if 0 <= id_ < self._num_reserved_ids:
+        decoded_ids.append(RESERVED_TOKENS[int(id_)])
       else:
-        decoded_ids.append(int2byte(id_ - numres))
-    if six.PY2:
-      return ''.join(decoded_ids)
-    # Python3: join byte arrays and then decode string
-    return b''.join(decoded_ids).decode("utf-8")
+        decoded_ids.append(chr(id_))
+
+    return ''.join(decoded_ids)

   @property
   def vocab_size(self):
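The reverted ByteTextEncoder assumes 8-bit Python 2 strings: encode shifts each byte's ordinal past the reserved ids, and decode maps reserved ids back to the reserved token strings and all other ids through chr(). A tiny standalone sketch of the resulting id layout, assuming the default of two reserved ids; the helper name is illustrative, not part of the class:

# Standalone sketch: id layout with the default two reserved ids.
RESERVED_TOKENS = ['<pad>', '<EOS>']
NUM_RESERVED_IDS = len(RESERVED_TOKENS)


def byte_encode(s):
  # Mirrors the reverted encode: shift each byte past the reserved ids.
  return [ord(c) + NUM_RESERVED_IDS for c in s]


print(byte_encode("ab"))  # [99, 100]; ids 0 and 1 stay free for <pad>/<EOS>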
@@ -123,16 +111,20 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
     """Initialize from a file, one token per line."""
     super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
     self._reverse = reverse
-    self._load_vocab_from_file(vocab_filename)
+    if vocab_filename is not None:
+      self._load_vocab_from_file(vocab_filename)

   def encode(self, sentence):
     """Converts a space-separated string of tokens to a list of ids."""
     ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
-    return ret[::-1] if self._reverse else ret
+    if self._reverse:
+      ret = ret[::-1]
+    return ret

   def decode(self, ids):
-    seq = reversed(ids) if self._reverse else ids
-    return ' '.join([self._safe_id_to_token(i) for i in seq])
+    if self._reverse:
+      ids = ids[::-1]
+    return ' '.join([self._safe_id_to_token(i) for i in ids])

   @property
   def vocab_size(self):
@@ -251,22 +243,15 @@ def _escaped_token_to_subtokens(self, escaped_token):
     """
     ret = []
     pos = 0
-    lesc = len(escaped_token)
-    while pos < lesc:
-      end = lesc
-      while end > pos:
+    while pos < len(escaped_token):
+      end = len(escaped_token)
+      while True:
        subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
        if subtoken != -1:
          break
        end -= 1
      ret.append(subtoken)
-      if end > pos:
-        pos = end
-      else:
-        # This kinda should not happen, but it does. Cop out by skipping the
-        # nonexistent subtoken from the returned list.
-        # print("Unable to find subtoken in string '{0}'".format(escaped_token))
-        pos += 1
+      pos = end
     return ret

   @classmethod
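The simplified _escaped_token_to_subtokens is a greedy longest-match scan: at each position it tries the longest remaining substring, shrinks it until a known subtoken is found, appends that subtoken's id, and jumps past it. The inner `while True` only terminates because the vocabulary always contains every single character, which the build_from_token_counts change below guarantees. A standalone sketch of the same greedy idea over a plain dict, with an illustrative vocabulary:

def greedy_subtokens(token, subtoken_to_id):
  # Greedy longest-match segmentation over a plain dict (illustrative).
  ret = []
  pos = 0
  while pos < len(token):
    end = len(token)
    while True:  # safe only if every single character is in the vocab
      if token[pos:end] in subtoken_to_id:
        break
      end -= 1
    ret.append(subtoken_to_id[token[pos:end]])
    pos = end
  return ret


vocab = {c: i for i, c in enumerate("abcdefgh")}  # every single character...
vocab["abc"] = 100                                # ...plus one longer subtoken
print(greedy_subtokens("abcde", vocab))           # [100, 3, 4]: 'abc', 'd', 'e'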
@@ -337,13 +322,13 @@ def build_from_token_counts(self,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
     for i in xrange(num_iterations):
-      counts = defaultdict(int)
+      counts = {}
       for token, count in six.iteritems(token_counts):
         escaped_token = self._escape_token(token)
         # we will count all tails of the escaped_token, starting from boundaries
         # determined by our current segmentation.
         if i == 0:
-          starts = xrange(len(escaped_token))
+          starts = list(range(len(escaped_token)))
         else:
           subtokens = self._escaped_token_to_subtokens(escaped_token)
           pos = 0
@@ -352,33 +337,31 @@
             starts.append(pos)
             pos += len(self.subtoken_to_subtoken_string(subtoken))
         for start in starts:
-          for end in xrange(start + 1, len(escaped_token)):
+          for end in xrange(start + 1, len(escaped_token) + 1):
             subtoken_string = escaped_token[start:end]
-            counts[subtoken_string] += count
+            counts[subtoken_string] = counts.get(subtoken_string, 0) + count
       # array of lists of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
-        lsub = len(subtoken_string)
-        # all subtoken strings of length 1 are included regardless of count
-        if count < min_count and lsub != 1:
+        if count < min_count or len(subtoken_string) <= 1:
           continue
-        while len(len_to_subtoken_strings) <= lsub:
+        while len(len_to_subtoken_strings) <= len(subtoken_string):
           len_to_subtoken_strings.append([])
-        len_to_subtoken_strings[lsub].append(subtoken_string)
+        len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
       for subtoken_strings in len_to_subtoken_strings[::-1]:
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count and len(subtoken_string) != 1:
-            # subtoken strings of length 1 are included regardless of count
+          if count < min_count:
             continue
           new_subtoken_strings.append((-count, subtoken_string))
           for l in xrange(1, len(subtoken_string)):
             counts[subtoken_string[:l]] -= count
-      # Make sure to include the underscore as a subtoken string
-      new_subtoken_strings.append((0, '_'))
+      # make sure we have all single characters.
+      new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
+                                   for i in xrange(2**8)])
       new_subtoken_strings.sort()
       self._init_from_list([''] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
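The vocabulary-building loop now counts candidate substrings in a plain dict via counts.get, includes the full token by extending the end bound to len(escaped_token) + 1, and finishes by adding all 256 single characters so every token can always be segmented. A standalone sketch of just that counting step on a toy token-count dict (the tokens and counts are made up):

# Toy stand-in for the real token counts (illustrative values only).
token_counts = {"dog_": 3, "dig_": 2}

counts = {}
for token, count in token_counts.items():
  # First iteration: every position is a start; count every tail substring,
  # including the full token (hence the "+ 1" on the end bound).
  for start in range(len(token)):
    for end in range(start + 1, len(token) + 1):
      subtoken_string = token[start:end]
      counts[subtoken_string] = counts.get(subtoken_string, 0) + count

print(counts["d"])    # 5 -> appears in both toy tokens
print(counts["og_"])  # 3 -> only in "dog_"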
@@ -407,19 +390,13 @@ def _load_from_file(self, filename):
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
     self._init_from_list(subtoken_strings)

   def _store_to_file(self, filename):
     with tf.gfile.Open(filename, 'w') as f:
       for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
-        else:
-          f.write('\'' + subtoken_string + '\'\n')
+        f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')

   def _escape_token(self, token):
     r"""Translate '\'->'\\' and '_'->'\u', then append '_'.

tensor2tensor/data_generators/tokenizer.py

File mode changed from 100755 to 100644
Lines changed: 16 additions & 8 deletions
@@ -45,21 +45,29 @@
 from __future__ import division
 from __future__ import print_function

+import array
 import string

 # Dependency imports

 from six.moves import xrange  # pylint: disable=redefined-builtin
-from collections import defaultdict
+

 class Tokenizer(object):
   """Vocab for breaking words into wordpieces.
   """

-  _SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace)
-
   def __init__(self):
-    self.token_counts = defaultdict(int)
+    self._separator_chars = string.punctuation + string.whitespace
+    self._separator_char_mask = array.array(
+        "l", [chr(i) in self._separator_chars for i in xrange(256)])
+    self.token_counts = dict()
+
+  def _increment_token_count(self, token):
+    if token in self.token_counts:
+      self.token_counts[token] += 1
+    else:
+      self.token_counts[token] = 1

   def encode(self, raw_text):
     """Encode a raw string as a list of tokens.
@@ -79,11 +87,11 @@ def encode(self, raw_text):
         token = raw_text[token_start:pos]
         if token != " " or token_start == 0:
           ret.append(token)
-          self.token_counts[token] += 1
+          self._increment_token_count(token)
         token_start = pos
     final_token = raw_text[token_start:]
     ret.append(final_token)
-    self.token_counts[final_token] += 1
+    self._increment_token_count(final_token)
     return ret

   def decode(self, tokens):
@@ -103,7 +111,7 @@ def decode(self, tokens):
     return ret

   def _is_separator_char(self, c):
-    return c in self._SEPARATOR_CHAR_SET
+    return self._separator_char_mask[ord(c)]

   def _is_word_char(self, c):
-    return c not in self._SEPARATOR_CHAR_SET
+    return not self._is_separator_char(c)
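The Tokenizer now precomputes a 256-entry array.array mask marking which byte values are separator characters (punctuation or whitespace), so _is_separator_char becomes an indexed lookup on ord(c) instead of a set membership test. A standalone sketch of the same mask, with an illustrative check:

import array
import string

# Build a 256-entry mask: mask[ord(c)] is 1 if c is punctuation/whitespace.
separator_chars = string.punctuation + string.whitespace
separator_char_mask = array.array(
    "l", [1 if chr(i) in separator_chars else 0 for i in range(256)])


def is_separator_char(c):
  return bool(separator_char_mask[ord(c)])


print(is_separator_char(" "))   # True
print(is_separator_char("a"))   # False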

0 commit comments