This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 4b4c800

Lukasz Kaiser authored and Ryan Sepassi committed
Removing obsolete problems, merging.
PiperOrigin-RevId: 162020990
1 parent c4407b8 commit 4b4c800

18 files changed, 1,394 additions and 226 deletions


tensor2tensor/bin/t2t-datagen

Lines changed: 15 additions & 37 deletions
@@ -101,6 +101,14 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
+    "ice_parsing_tokens": (
+        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
+                                                   True, "ice", 2**13, 2**8),
+        lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
+                                                   False, "ice", 2**13, 2**8)),
+    "ice_parsing_characters": (
+        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
+        lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13),
         lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)),
@@ -109,11 +117,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
                                                     2**14, 2**9),
         lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
                                                     2**14, 2**9)),
-    "wsj_parsing_tokens_32k": (
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
-                                                    2**15, 2**9),
-        lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
-                                                    2**15, 2**9)),
     "wmt_enfr_characters": (
         lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
         lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
@@ -139,6 +142,12 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
+    "wmt_zhen_tokens_32k": (
+        lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
+                                                   2**15, 2**15),
+        lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, False,
+                                                   2**15, 2**15)
+    ),
     "lm1b_32k": (
         lambda: lm1b.generator(FLAGS.tmp_dir, True),
         lambda: lm1b.generator(FLAGS.tmp_dir, False)
@@ -159,26 +168,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "image_cifar10_test": (
         lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
         lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
-    "image_mscoco_characters_tune": (
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 70000),
-        lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 10000, 70000)),
     "image_mscoco_characters_test": (
         lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
         lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
-    "image_mscoco_tokens_8k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**13,
-            vocab_size=2**13)),
     "image_mscoco_tokens_8k_test": (
         lambda: image.mscoco_generator(
             FLAGS.tmp_dir,
@@ -192,20 +184,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
             40000,
             vocab_filename="tokens.vocab.%d" % 2**13,
             vocab_size=2**13)),
-    "image_mscoco_tokens_32k_tune": (
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15),
-        lambda: image.mscoco_generator(
-            FLAGS.tmp_dir,
-            True,
-            10000,
-            70000,
-            vocab_filename="tokens.vocab.%d" % 2**15,
-            vocab_size=2**15)),
     "image_mscoco_tokens_32k_test": (
         lambda: image.mscoco_generator(
             FLAGS.tmp_dir,
@@ -386,7 +364,7 @@ def generate_data_for_problem(problem):
 
 def generate_data_for_registered_problem(problem_name):
   problem = registry.problem(problem_name)
-  problem.generate_data(FLAGS.data_dir)
+  problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
 
 
 if __name__ == "__main__":
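
Each entry in _SUPPORTED_PROBLEM_GENERATORS, including the new ice_parsing_* and wmt_zhen_tokens_32k problems, maps a problem name to a pair of zero-argument callables: the first yields training examples, the second dev examples. A self-contained toy sketch of that convention (_toy_generator and TOY_PROBLEM_GENERATORS are illustrative names, not part of this commit):

# Toy illustration of the (training_generator, dev_generator) pairing used by
# _SUPPORTED_PROBLEM_GENERATORS; everything below is hypothetical example code.
def _toy_generator(is_training):
  count = 4 if is_training else 2
  for i in range(count):
    yield {"inputs": [i], "targets": [i + 1]}

TOY_PROBLEM_GENERATORS = {
    "toy_copy": (lambda: _toy_generator(True),
                 lambda: _toy_generator(False)),
}

train_gen, dev_gen = TOY_PROBLEM_GENERATORS["toy_copy"]
print(list(train_gen()))  # training split
print(list(dev_gen()))    # dev split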

tensor2tensor/data_generators/algorithmic.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ class AlgorithmicIdentityBinary40(problem.Problem):
   def num_symbols(self):
     return 2
 
-  def generate_data(self, data_dir):
+  def generate_data(self, data_dir, _):
     utils.generate_files(
         identity_generator(self.num_symbols, 40, 100000),
         self.training_filepaths(data_dir, 100))
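
The added underscore parameter simply absorbs the tmp_dir argument that callers now pass (see the t2t-datagen and problem.py changes in this commit); this problem synthesizes its data in memory and has nothing to download. A minimal usage sketch with placeholder paths:

# Hypothetical call site with placeholder paths: both directories are passed,
# but AlgorithmicIdentityBinary40 ignores the second one (the "_" parameter).
from tensor2tensor.data_generators.algorithmic import AlgorithmicIdentityBinary40

prob = AlgorithmicIdentityBinary40()
prob.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")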

tensor2tensor/data_generators/generator_utils.py

Lines changed: 40 additions & 0 deletions
@@ -300,6 +300,46 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
   return vocab
 
 
+def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
+                                 index, vocab_filename, vocab_size):
+  r"""Generate a vocabulary from a tabbed source file.
+
+  The source is a file of source, target pairs, where each line contains
+  a source string and a target string, separated by a tab ('\t') character.
+  The index parameter specifies 0 for the source or 1 for the target.
+
+  Args:
+    tmp_dir: path to the temporary directory.
+    source_filename: the name of the tab-separated source file.
+    index: index.
+    vocab_filename: the name of the vocabulary file.
+    vocab_size: vocabulary size.
+
+  Returns:
+    The vocabulary.
+  """
+  vocab_filepath = os.path.join(tmp_dir, vocab_filename)
+  if os.path.exists(vocab_filepath):
+    vocab = text_encoder.SubwordTextEncoder(vocab_filepath)
+    return vocab
+
+  # Use Tokenizer to count the word occurrences.
+  filepath = os.path.join(tmp_dir, source_filename)
+  with tf.gfile.GFile(filepath, mode="r") as source_file:
+    for line in source_file:
+      line = line.strip()
+      if line and "\t" in line:
+        parts = line.split("\t", maxsplit=1)
+        part = parts[index].strip()
+        _ = tokenizer.encode(text_encoder.native_to_unicode(part))
+
+  vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
+      vocab_size, tokenizer.token_counts, 1,
+      min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
+  vocab.store_to_file(vocab_filepath)
+  return vocab
+
+
 def read_records(filename):
   reader = tf.python_io.tf_record_iterator(filename)
   records = []
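
A hedged usage sketch of the new helper follows; the file name "ice_parsing_train.pairs" and the paths are hypothetical stand-ins for a tab-separated source<TAB>target file placed in tmp_dir, and are not created by this commit. Only the function signature comes from the diff above:

# Hypothetical usage of get_or_generate_tabbed_vocab; file names and paths
# below are placeholders.
from tensor2tensor.data_generators import generator_utils

# Build (or load, if already cached) a subword vocabulary from the source
# column (index=0) of a tab-separated pair file.
source_vocab = generator_utils.get_or_generate_tabbed_vocab(
    "/tmp/t2t_tmp", "ice_parsing_train.pairs", 0,
    "tokens.vocab.%d" % 2**13, 2**13)
# The target-side vocabulary would use index=1.
target_vocab = generator_utils.get_or_generate_tabbed_vocab(
    "/tmp/t2t_tmp", "ice_parsing_train.pairs", 1,
    "targets.vocab.%d" % 2**13, 2**13)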

tensor2tensor/data_generators/image.py

Lines changed: 47 additions & 0 deletions
@@ -33,6 +33,9 @@
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from six.moves import zip  # pylint: disable=redefined-builtin
 from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_encoder
+from tensor2tensor.utils import registry
 
 import tensorflow as tf
 
@@ -300,3 +303,47 @@ def mscoco_generator(tmp_dir,
         "image/height": [height],
         "image/width": [width]
     }
+
+# French street names dataset.
+
+
+@registry.register_problem
+class ImageFSNS(problem.Problem):
+  """Problem spec for French Street Name recognition."""
+
+  def generate_data(self, data_dir, tmp_dir):
+    list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
+                "street/python/fsns_urls.txt")
+    fsns_urls = generator_utils.maybe_download(
+        tmp_dir, "fsns_urls.txt", list_url)
+    fsns_files = [f.strip() for f in open(fsns_urls, "r")
+                  if f.startswith("http://")]
+    for url in fsns_files:
+      if "/train/train" in url:
+        generator_utils.maybe_download(
+            data_dir, "image_fsns-train" + url[-len("-00100-of-00512"):], url)
+      elif "/validation/validation" in url:
+        generator_utils.maybe_download(
+            data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url)
+      elif "charset" in url:
+        generator_utils.maybe_download(
+            data_dir, "charset_size134.txt", url)
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
+    # This vocab file must be present within the data directory.
+    vocab_filename = os.path.join(model_hparams.data_dir, "charset_size134.txt")
+    subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename)
+    p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size)
+    p.vocabulary = {
+        "inputs": text_encoder.TextEncoder(),
+        "targets": subtokenizer,
+    }
+    p.batch_size_multiplier = 256
+    p.max_expected_batch_size_per_shard = 2
+    vocab_size = 144
+    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
+    p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
+    p.input_space_id = problem.SpaceID.DIGIT_0
+    p.target_space_id = problem.SpaceID.DIGIT_1
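
Since ImageFSNS is decorated with @registry.register_problem, it should be reachable through the problem registry. A minimal sketch, assuming the registry's default snake_case name for the class is "image_fsns" (an assumption, as are the paths):

# Minimal sketch; the registry name "image_fsns" and the paths are assumptions.
from tensor2tensor.data_generators import image  # registers ImageFSNS as a side effect
from tensor2tensor.utils import registry

fsns = registry.problem("image_fsns")
fsns.generate_data("/tmp/t2t_data", "/tmp/t2t_tmp")  # fetches the FSNS shards and charset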

tensor2tensor/data_generators/problem.py

Lines changed: 9 additions & 1 deletion
@@ -59,6 +59,14 @@ class SpaceID(object):
   PARSE_CHR = 14
   # Parse tokens
   PARSE_TOK = 15
+  # Chinese tokens
+  ZH_TOK = 16
+  # Icelandic characters
+  ICE_CHAR = 17
+  # Icelandic tokens
+  ICE_TOK = 18
+  # Icelandic parse tokens
+  ICE_PARSE_TOK = 19
 
 
 class Problem(object):
@@ -97,7 +105,7 @@ class Problem(object):
   # BEGIN SUBCLASS INTERFACE
   # ============================================================================
 
-  def generate_data(self, data_dir):
+  def generate_data(self, data_dir, tmp_dir):
     raise NotImplementedError()
 
   def hparams(self, defaults, model_hparams):
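
With the widened subclass interface, a Problem can download raw data into tmp_dir and write its finished files into data_dir. A hedged sketch of a hypothetical subclass follows; MyTextProblem, the URL, and the inline generator are illustrative, while problem.Problem, registry.register_problem, generator_utils.maybe_download, generator_utils.generate_files, and training_filepaths come from the repository:

# Illustrative subclass of the updated generate_data(data_dir, tmp_dir)
# interface; the class, URL, and corpus file are hypothetical.
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.utils import registry


@registry.register_problem
class MyTextProblem(problem.Problem):
  """Hypothetical text problem used only to illustrate the new signature."""

  def generate_data(self, data_dir, tmp_dir):
    # Raw downloads land in tmp_dir; finished training shards go to data_dir.
    corpus = generator_utils.maybe_download(
        tmp_dir, "my_corpus.txt", "http://example.com/my_corpus.txt")

    def example_gen():
      with open(corpus) as f:
        for line in f:
          ids = [ord(c) for c in line.strip()]
          yield {"inputs": ids, "targets": ids}

    generator_utils.generate_files(
        example_gen(), self.training_filepaths(data_dir, 10))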
