
Commit f6799b9

Author: Ryan Sepassi
File/code moves
PiperOrigin-RevId: 164058229
1 parent: 9394d0e


41 files changed: +1648 / -1512 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ python -c "from tensor2tensor.models.transformer import Transformer"
 **Datasets** are all standardized on `TFRecord` files with `tensorflow.Example`
 protocol buffers. All datasets are registered and generated with the
 [data
-generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/generator.py)
+generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen)
 and many common sequence datasets are already available for generation and use.
 
 ### Problems and Modalities
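
The t2t-datagen script linked above writes the TFRecord files of tensorflow.Example protocol buffers that this README paragraph describes. A minimal sketch of inspecting one such file, assuming TensorFlow 1.x and a purely hypothetical shard path:

import tensorflow as tf

# Hypothetical path to a shard written by t2t-datagen; substitute a real file.
record_path = "/tmp/t2t_data/my_problem-train-00000-of-00010"

for serialized in tf.python_io.tf_record_iterator(record_path):
  example = tf.train.Example.FromString(serialized)
  # Each feature maps a string name to a list of ints, floats, or byte strings.
  print(sorted(example.features.feature.keys()))
  break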

tensor2tensor/data_generators/generator.py renamed to tensor2tensor/bin/t2t-datagen

Lines changed: 14 additions & 12 deletions
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 # coding=utf-8
 # Copyright 2017 The Tensor2Tensor Authors.
 #
@@ -15,14 +16,15 @@
 
 """Produces the training and dev data for --problem into --data_dir.
 
-generator.py produces sharded and shuffled TFRecord files of tensorflow.Example
-protocol buffers for a variety of datasets registered in this file.
+Produces sharded and shuffled TFRecord files of tensorflow.Example protocol
+buffers for a variety of registered datasets.
 
-All datasets are registered in _SUPPORTED_PROBLEM_GENERATORS. Each entry maps a
-string name (selectable on the command-line with --problem) to a function that
-takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
-yields for each training example a dictionary mapping string feature names to
-lists of {string, int, float}. The generator will be run once for each mode.
+All Problems are registered with @registry.register_problem or are in
+_SUPPORTED_PROBLEM_GENERATORS in this file. Each entry maps a string name
+(selectable on the command-line with --problem) to a function that takes 2
+arguments - input_directory and mode (one of "train" or "dev") - and yields for
+each training example a dictionary mapping string feature names to lists of
+{string, int, float}. The generator will be run once for each mode.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -228,8 +230,7 @@ def generate_data_for_problem(problem):
   num_shards = FLAGS.num_shards or 10
   tf.logging.info("Generating training data for %s.", problem)
   train_output_files = generator_utils.train_data_filenames(
-      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir,
-      num_shards)
+      problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards)
   generator_utils.generate_files(training_gen(), train_output_files,
                                  FLAGS.max_cases)
   tf.logging.info("Generating development data for %s.", problem)
@@ -249,9 +250,10 @@ def generate_data_for_registered_problem(problem_name):
     raise ValueError("--num_shards should not be set for registered Problem.")
   problem = registry.problem(problem_name)
   task_id = None if FLAGS.task_id < 0 else FLAGS.task_id
-  problem.generate_data(os.path.expanduser(FLAGS.data_dir),
-                        os.path.expanduser(FLAGS.tmp_dir),
-                        task_id=task_id)
+  problem.generate_data(
+      os.path.expanduser(FLAGS.data_dir),
+      os.path.expanduser(FLAGS.tmp_dir),
+      task_id=task_id)
 
 
 if __name__ == "__main__":
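
The updated docstring above spells out the generator contract: a function that takes an input directory and a mode ("train" or "dev") and yields, per example, a dict from feature names to lists of ints, floats, or strings. A hypothetical generator written only to illustrate that contract (the name, the fake data, and the registration key are assumptions, not part of this commit):

def my_toy_copy_generator(input_directory, mode):
  """Illustrative generator: yields feature-name -> list dicts, per the docstring."""
  del input_directory  # A real generator would read raw data from here.
  num_examples = 1000 if mode == "train" else 100
  for i in range(num_examples):
    token_ids = [i % 7 + 2, i % 5 + 2]  # fake token ids
    yield {"inputs": token_ids, "targets": token_ids}

Such a generator would become selectable via --problem once registered, e.g. under a key in _SUPPORTED_PROBLEM_GENERATORS; the exact registration form here is an assumption.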

tensor2tensor/trainer.py renamed to tensor2tensor/bin/t2t-trainer

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 # coding=utf-8
 # Copyright 2017 The Tensor2Tensor Authors.
 #

tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
 
+
 # Problem modules that require optional dependencies
 # pylint: disable=g-import-not-at-top
 try:

tensor2tensor/data_generators/image.py

Lines changed: 27 additions & 18 deletions
@@ -36,7 +36,7 @@
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.models import common_layers
+from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
@@ -76,10 +76,11 @@ class ImageFSNS(ImageProblem):
   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     list_url = ("https://raw.githubusercontent.com/tensorflow/models/master/"
                 "street/python/fsns_urls.txt")
-    fsns_urls = generator_utils.maybe_download(
-        tmp_dir, "fsns_urls.txt", list_url)
-    fsns_files = [f.strip() for f in open(fsns_urls, "r")
-                  if f.startswith("http://")]
+    fsns_urls = generator_utils.maybe_download(tmp_dir, "fsns_urls.txt",
+                                               list_url)
+    fsns_files = [
+        f.strip() for f in open(fsns_urls, "r") if f.startswith("http://")
+    ]
     for url in fsns_files:
       if "/train/train" in url:
         generator_utils.maybe_download(
@@ -88,8 +89,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
         generator_utils.maybe_download(
             data_dir, "image_fsns-dev" + url[-len("-00100-of-00512"):], url)
       elif "charset" in url:
-        generator_utils.maybe_download(
-            data_dir, "charset_size134.txt", url)
+        generator_utils.maybe_download(data_dir, "charset_size134.txt", url)
 
   def feature_encoders(self, data_dir):
     # This vocab file must be present within the data directory.
@@ -111,8 +111,8 @@ def hparams(self, defaults, model_hparams):
 
   def example_reading_spec(self):
     label_key = "image/unpadded_label"
-    return super(ImageFSNS, self).example_reading_spec(self,
-                                                       label_key=label_key)
+    return super(ImageFSNS, self).example_reading_spec(
+        self, label_key=label_key)
 
 
 class Image2ClassProblem(ImageProblem):
@@ -161,6 +161,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
 
 def imagenet_preprocess_examples(examples, mode):
   """Preprocessing used for Imagenet and similar problems."""
+
   def preprocess(img):
     img = tf.image.resize_images(img, [360, 360])
     img = common_layers.image_augmentation(tf.to_float(img) / 255.)
@@ -215,8 +216,8 @@ def is_small(self):
 
   def preprocess_examples(self, examples, mode):
     examples = imagenet_preprocess_examples(examples, mode)
-    examples["inputs"] = tf.to_int64(tf.image.resize_images(
-        examples["inputs"], [32, 32]))
+    examples["inputs"] = tf.to_int64(
+        tf.image.resize_images(examples["inputs"], [32, 32]))
 
 
 def image_generator(images, labels):
@@ -665,12 +666,20 @@ def generator(self, data_dir, tmp_dir, is_training):
     vocab_filename = "vocab.endefr.%d" % self.targeted_vocab_size
     if is_training:
       return mscoco_generator(
-          data_dir, tmp_dir, True, 80000,
-          vocab_filename=vocab_filename, vocab_size=self.targeted_vocab_size)
+          data_dir,
+          tmp_dir,
+          True,
+          80000,
+          vocab_filename=vocab_filename,
+          vocab_size=self.targeted_vocab_size)
     else:
       return mscoco_generator(
-          data_dir, tmp_dir, False, 40000,
-          vocab_filename=vocab_filename, vocab_size=self.targeted_vocab_size)
+          data_dir,
+          tmp_dir,
+          False,
+          40000,
+          vocab_filename=vocab_filename,
+          vocab_size=self.targeted_vocab_size)
 
 
 @registry.register_problem
@@ -690,8 +699,8 @@ def targeted_vocab_size(self):
 def _get_celeba(directory):
   """Download and extract CELEBA to directory unless it is there."""
   # path = os.path.join(directory, _CELEBA_NAME)
-  path = generator_utils.maybe_download_from_drive(directory,
-                                                   _CELEBA_NAME, _CELEBA_URL)
+  path = generator_utils.maybe_download_from_drive(directory, _CELEBA_NAME,
+                                                   _CELEBA_URL)
   if not tf.gfile.Exists(path):
     zipfile.ZipFile(path + ".zip", "r").extractall(directory)
 
@@ -711,7 +720,7 @@ def celeba_generator(tmp_dir, how_many, start_from=0):
   """
   _get_celeba(tmp_dir)
   image_files = tf.gfile.Glob(os.path.join(tmp_dir, _CELEBA_NAME) + "/*.jpg")
-  for filename in image_files[start_from:start_from+how_many]:
+  for filename in image_files[start_from:start_from + how_many]:
     with tf.gfile.Open(filename, "r") as f:
       encoded_image_data = f.read()
       yield {

tensor2tensor/data_generators/problem_hparams.py

Lines changed: 47 additions & 41 deletions
@@ -25,7 +25,7 @@
 # Dependency imports
 
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.models import modalities  # pylint: disable=unused-import
+from tensor2tensor.layers import modalities  # pylint: disable=unused-import
 from tensor2tensor.utils import registry
 
 import tensorflow as tf
@@ -202,8 +202,7 @@ def default_problem_hparams():
       # the targets. For instance `problem_copy` will copy the inputs, but
       # `problem_rev_copy` will copy the targets.
       was_reversed=False,
-      was_copy=False,
-  )
+      was_copy=False,)
 
 
 def test_problem_hparams(unused_model_hparams, input_vocab_size,
@@ -327,9 +326,7 @@ def lm1b_32k(model_hparams):
   encoder = text_encoder.SubwordTextEncoder(
       os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder"))
   p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
-  p.vocabulary = {
-      "targets": encoder
-  }
+  p.vocabulary = {"targets": encoder}
   p.target_space_id = 3
   return p
 
@@ -343,9 +340,7 @@ def lm1b_characters(unused_model_hparams):
   p.input_modality = {}
   encoder = text_encoder.ByteTextEncoder()
   p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
-  p.vocabulary = {
-      "targets": encoder
-  }
+  p.vocabulary = {"targets": encoder}
   p.target_space_id = 2
   return p
 
@@ -358,10 +353,7 @@ def wiki_32k(model_hparams):
   modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size)
   p.input_modality = {"inputs": modality_spec}
   p.target_modality = modality_spec
-  p.vocabulary = {
-      "inputs": encoder,
-      "targets": encoder
-  }
+  p.vocabulary = {"inputs": encoder, "targets": encoder}
   p.target_space_id = 3
   return p
 
@@ -430,9 +422,7 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size):
   return p
 
 
-def wsj_parsing_tokens(model_hparams,
-                       prefix,
-                       wrong_source_vocab_size,
+def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
                        wrong_target_vocab_size):
   """English to parse tree translation benchmark.
 
@@ -487,11 +477,9 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
   p = default_problem_hparams()
   # This vocab file must be present within the data directory.
   source_vocab_filename = os.path.join(
-      model_hparams.data_dir,
-      "ice_source.vocab.%d" % wrong_source_vocab_size)
-  target_vocab_filename = os.path.join(
-      model_hparams.data_dir,
-      "ice_target.vocab.256")
+      model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size)
+  target_vocab_filename = os.path.join(model_hparams.data_dir,
+                                       "ice_target.vocab.256")
   source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
   target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
   p.input_modality = {
@@ -502,7 +490,7 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
       "inputs": source_subtokenizer,
      "targets": target_subtokenizer,
   }
-  p.input_space_id = 18 # Icelandic tokens
+  p.input_space_id = 18  # Icelandic tokens
   p.target_space_id = 19  # Icelandic parse tokens
   return p
 
@@ -534,23 +522,41 @@ def image_celeba(unused_model_hparams):
 # Dictionary of named hyperparameter settings for various problems.
 # This is only accessed through the problem_hparams function below.
 PROBLEM_HPARAMS_MAP = {
-    "audio_timit_characters_tune": audio_timit_characters,
-    "audio_timit_characters_test": audio_timit_characters,
-    "audio_timit_tokens_8k_tune": lambda p: audio_timit_tokens(p, 2**13),
-    "audio_timit_tokens_8k_test": lambda p: audio_timit_tokens(p, 2**13),
-    "audio_wsj_characters_tune": audio_wsj_characters,
-    "audio_wsj_characters_test": audio_wsj_characters,
-    "audio_wsj_tokens_8k_tune": lambda p: audio_wsj_tokens(p, 2**13),
-    "audio_wsj_tokens_8k_test": lambda p: audio_wsj_tokens(p, 2**13),
-    "lm1b_characters": lm1b_characters,
-    "lm1b_32k": lm1b_32k,
-    "wiki_32k": wiki_32k,
-    "ice_parsing_characters": wmt_parsing_characters,
-    "ice_parsing_tokens": lambda p: ice_parsing_tokens(p, 2**13),
-    "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13),
-    "wsj_parsing_tokens_16k": lambda p: wsj_parsing_tokens(  # pylint: disable=g-long-lambda
-        p, "wsj", 2**14, 2**9),
-    "wmt_ende_bpe32k": wmt_ende_bpe32k,
-    "image_celeba_tune": image_celeba,
-    "img2img_imagenet": img2img_imagenet,
+    "audio_timit_characters_tune":
+        audio_timit_characters,
+    "audio_timit_characters_test":
+        audio_timit_characters,
+    "audio_timit_tokens_8k_tune":
+        lambda p: audio_timit_tokens(p, 2**13),
+    "audio_timit_tokens_8k_test":
+        lambda p: audio_timit_tokens(p, 2**13),
+    "audio_wsj_characters_tune":
+        audio_wsj_characters,
+    "audio_wsj_characters_test":
+        audio_wsj_characters,
+    "audio_wsj_tokens_8k_tune":
+        lambda p: audio_wsj_tokens(p, 2**13),
+    "audio_wsj_tokens_8k_test":
+        lambda p: audio_wsj_tokens(p, 2**13),
+    "lm1b_characters":
+        lm1b_characters,
+    "lm1b_32k":
+        lm1b_32k,
+    "wiki_32k":
+        wiki_32k,
+    "ice_parsing_characters":
+        wmt_parsing_characters,
+    "ice_parsing_tokens":
+        lambda p: ice_parsing_tokens(p, 2**13),
+    "wmt_parsing_tokens_8k":
+        lambda p: wmt_parsing_tokens(p, 2**13),
+    "wsj_parsing_tokens_16k":
+        lambda p: wsj_parsing_tokens(  # pylint: disable=g-long-lambda
+            p, "wsj", 2**14, 2**9),
+    "wmt_ende_bpe32k":
+        wmt_ende_bpe32k,
+    "image_celeba_tune":
+        image_celeba,
+    "img2img_imagenet":
+        img2img_imagenet,
 }
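
The reflowed PROBLEM_HPARAMS_MAP keeps the same lookup pattern: each problem name maps either to an hparams-building function or to a lambda that fixes extra arguments such as a vocab size. A self-contained toy sketch of that pattern, with illustrative names and values that are not part of the real map:

def _toy_tokens_hparams(model_hparams, vocab_size):
  # Stand-in for functions like audio_timit_tokens(p, 2**13) above.
  del model_hparams  # unused in this toy example
  return {"vocab_size": vocab_size, "target_space_id": 3}

TOY_HPARAMS_MAP = {
    "toy_tokens_8k": lambda p: _toy_tokens_hparams(p, 2**13),
    "toy_tokens_32k": lambda p: _toy_tokens_hparams(p, 2**15),
}

print(TOY_HPARAMS_MAP["toy_tokens_8k"](None)["vocab_size"])  # prints 8192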
