Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 6cea4c4

Browse files
authored
Merge pull request #765 from lukaszkaiser/push
Merging internal updates.
2 parents 40758df + c98fab4 commit 6cea4c4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2316
-690
lines changed

.travis.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ env:
1111
- TF_VERSION="1.5.*"
1212
- TF_VERSION="1.6.*"
1313
- TF_VERSION="1.7.*"
14+
- TF_VERSION="1.8.*"
1415
matrix:
1516
exclude:
1617
- python: "3.6"
1718
env: TF_VERSION="1.5.*"
1819
- python: "3.6"
1920
env: TF_VERSION="1.6.*"
21+
- python: "3.6"
22+
env: TF_VERSION="1.7.*"
2023
before_install:
2124
- echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
2225
- curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ You can chat with us on
2626

2727
### Quick Start
2828

29-
[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your
30-
browser using a free VM from Google, no installation needed.
31-
Alternatively, here is a one-command version that installs T2T, downloads MNIST,
32-
trains a model and evaluates it:
29+
[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb)
30+
explains T2T and runs in your browser using a free VM from Google,
31+
no installation needed. Alternatively, here is a one-command version that
32+
installs T2T, downloads MNIST, trains a model and evaluates it:
3333

3434
```
3535
pip install tensor2tensor && t2t-trainer \

docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ research](https://research.googleblog.com/2017/06/accelerating-deep-learning-res
1919
## Basics
2020

2121
* [Walkthrough](walkthrough.md): Install and run.
22-
* [IPython notebook](https://goo.gl/wkHexj): Get a hands-on experience.
22+
* [IPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb): Get a hands-on experience.
2323
* [Overview](overview.md): How all parts of T2T code are connected.
2424
* [New Problem](new_problem.md): Train T2T models on your data.
2525
* [New Model](new_model.md): Create your own T2T model.

docs/walkthrough.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ You can chat with us on
2626

2727
### Quick Start
2828

29-
[This iPython notebook](https://goo.gl/wkHexj) explains T2T and runs in your
30-
browser using a free VM from Google, no installation needed.
31-
Alternatively, here is a one-command version that installs T2T, downloads MNIST,
32-
trains a model and evaluates it:
29+
[This iPython notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb)
30+
explains T2T and runs in your browser using a free VM from Google,
31+
no installation needed. Alternatively, here is a one-command version that
32+
installs T2T, downloads MNIST, trains a model and evaluates it:
3333

3434
```
3535
pip install tensor2tensor && t2t-trainer \

tensor2tensor/data_generators/all_problems.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"tensor2tensor.data_generators.ptb",
4747
"tensor2tensor.data_generators.snli",
4848
"tensor2tensor.data_generators.squad",
49+
"tensor2tensor.data_generators.subject_verb_agreement",
4950
"tensor2tensor.data_generators.translate_encs",
5051
"tensor2tensor.data_generators.translate_ende",
5152
"tensor2tensor.data_generators.translate_enet",
@@ -56,6 +57,7 @@
5657
"tensor2tensor.data_generators.twentybn",
5758
"tensor2tensor.data_generators.wiki",
5859
"tensor2tensor.data_generators.wikisum.wikisum",
60+
"tensor2tensor.data_generators.wikitext103",
5961
"tensor2tensor.data_generators.wsj_parsing",
6062
]
6163

tensor2tensor/data_generators/generator_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def to_example(dictionary):
4444
features = {}
4545
for (k, v) in six.iteritems(dictionary):
4646
if not v:
47-
raise ValueError("Empty generated field: %s", str((k, v)))
47+
raise ValueError("Empty generated field: %s" % str((k, v)))
4848
if isinstance(v[0], six.integer_types):
4949
features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
5050
elif isinstance(v[0], float):
@@ -130,7 +130,8 @@ def outputs_exist(filenames):
130130
return out_fname
131131

132132

133-
def generate_files(generator, output_filenames, max_cases=None):
133+
def generate_files(generator, output_filenames,
134+
max_cases=None, cycle_every_n=1):
134135
"""Generate cases from a generator and save as TFRecord files.
135136
136137
Generated cases are transformed to tf.Example protos and saved as TFRecords
@@ -141,6 +142,8 @@ def generate_files(generator, output_filenames, max_cases=None):
141142
output_filenames: List of output file paths.
142143
max_cases: maximum number of cases to get from the generator;
143144
if None (default), we use the generator until StopIteration is raised.
145+
cycle_every_n: how many cases from the generator to take before
146+
switching to the next shard; by default set to 1, switch every case.
144147
"""
145148
if outputs_exist(output_filenames):
146149
tf.logging.info("Skipping generator because outputs files exist")
@@ -159,7 +162,8 @@ def generate_files(generator, output_filenames, max_cases=None):
159162
break
160163
example = to_example(case)
161164
writers[shard].write(example.SerializeToString())
162-
shard = (shard + 1) % num_shards
165+
if counter % cycle_every_n == 0:
166+
shard = (shard + 1) % num_shards
163167

164168
for writer in writers:
165169
writer.close()
@@ -341,6 +345,7 @@ def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
341345
"""Generate a vocabulary from the datasets in sources."""
342346

343347
def generate():
348+
"""Generate lines for vocabulary generation."""
344349
tf.logging.info("Generating vocab from: %s", str(sources))
345350
for source in sources:
346351
url = source[0]

tensor2tensor/data_generators/gym.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21-
from collections import deque
22-
2321
import functools
22+
import os
23+
2424
# Dependency imports
25+
2526
import gym
2627

2728
from tensor2tensor.data_generators import problem
@@ -62,9 +63,7 @@ def num_target_frames(self):
6263
return 1
6364

6465
def eval_metrics(self):
65-
eval_metrics = [
66-
metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ,
67-
metrics.Metrics.NEG_LOG_PERPLEXITY]
66+
eval_metrics = [metrics.Metrics.ACC, metrics.Metrics.ACC_PER_SEQ]
6867
return eval_metrics
6968

7069
@property
@@ -108,6 +107,10 @@ def num_rewards(self):
108107
def num_steps(self):
109108
raise NotImplementedError()
110109

110+
@property
111+
def total_number_of_frames(self):
112+
return self.num_steps
113+
111114
@property
112115
def min_reward(self):
113116
raise NotImplementedError()
@@ -126,13 +129,13 @@ def hparams(self, defaults, unused_model_hparams):
126129
p.target_space_id = problem.SpaceID.IMAGE
127130

128131
def generate_samples(self, data_dir, tmp_dir, unused_dataset_split):
129-
next_obs = self.env.reset()
132+
next_observation = self.env.reset()
130133
for _ in range(self.num_steps):
131-
observation = next_obs
134+
observation = next_observation
132135
action = self.get_action(observation)
133-
next_obs, reward, done, _ = self.env.step(action)
136+
next_observation, reward, done, _ = self.env.step(action)
134137
if done:
135-
next_obs = self.env.reset()
138+
next_observation = self.env.reset()
136139
yield {"frame": observation,
137140
"action": [action],
138141
"done": [done],
@@ -184,23 +187,22 @@ class GymDiscreteProblemWithAgent(GymPongRandom5k):
184187
def __init__(self, *args, **kwargs):
185188
super(GymDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
186189
self._env = None
187-
self.history_size = 2
190+
self.debug_dump_frames_path = "debug_frames_env"
188191

189192
# defaults
190193
self.environment_spec = lambda: gym.make("PongDeterministic-v4")
191-
self.in_graph_wrappers = [(atari.MaxAndSkipWrapper, {"skip": 4})]
194+
self.in_graph_wrappers = []
192195
self.collect_hparams = rl.atari_base()
193-
self.settable_num_steps = 1000
196+
self.settable_num_steps = 20000
194197
self.simulated_environment = None
195-
self.warm_up = 70
198+
self.warm_up = 10
196199

197200
@property
198201
def num_steps(self):
199202
return self.settable_num_steps
200203

201204
def _setup(self):
202-
in_graph_wrappers = [(atari.ShiftRewardWrapper, {"add_value": 2}),
203-
(atari.MemoryWrapper, {})] + self.in_graph_wrappers
205+
in_graph_wrappers = [(atari.MemoryWrapper, {})] + self.in_graph_wrappers
204206
env_hparams = tf.contrib.training.HParams(
205207
in_graph_wrappers=in_graph_wrappers,
206208
simulated_environment=self.simulated_environment)
@@ -229,41 +231,41 @@ def _setup(self):
229231

230232
self.avilable_data_size_op = atari.MemoryWrapper.singleton.speculum.size()
231233
self.data_get_op = atari.MemoryWrapper.singleton.speculum.dequeue()
232-
self.history_buffer = deque(maxlen=self.history_size+1)
233234

234235
def restore_networks(self, sess):
235236
if FLAGS.agent_policy_path:
236237
model_saver = tf.train.Saver(
237-
tf.global_variables(".*network_parameters.*"))
238+
tf.global_variables(".*network_parameters.*"))
238239
model_saver.restore(sess, FLAGS.agent_policy_path)
239240

240241
def generate_encoded_samples(self, data_dir, tmp_dir, unused_dataset_split):
241242
self._setup()
243+
self.debug_dump_frames_path = os.path.join(
244+
data_dir, self.debug_dump_frames_path)
242245

243246
with tf.Session() as sess:
244247
sess.run(tf.global_variables_initializer())
245248
self.restore_networks(sess)
246-
249+
# Actions are shifted by 1 by MemoryWrapper, compensate here.
250+
avilable_data_size = sess.run(self.avilable_data_size_op)
251+
if avilable_data_size < 1:
252+
sess.run(self.collect_trigger_op)
247253
pieces_generated = 0
254+
observ, reward, _, _ = sess.run(self.data_get_op)
248255
while pieces_generated < self.num_steps + self.warm_up:
249256
avilable_data_size = sess.run(self.avilable_data_size_op)
250-
if avilable_data_size > 0:
251-
observ, reward, action, _ = sess.run(self.data_get_op)
252-
self.history_buffer.append(observ)
253-
254-
if len(self.history_buffer) == self.history_size + 1:
255-
pieces_generated += 1
256-
ret_dict = {"image/encoded": [observ],
257-
"image/format": ["png"],
258-
"image/height": [self.frame_height],
259-
"image/width": [self.frame_width],
260-
"action": [int(action)],
261-
"done": [int(False)],
262-
"reward": [int(reward) - self.min_reward]}
263-
if pieces_generated > self.warm_up:
264-
yield ret_dict
265-
else:
257+
if avilable_data_size < 1:
266258
sess.run(self.collect_trigger_op)
259+
next_observ, next_reward, action, _ = sess.run(self.data_get_op)
260+
yield {"image/encoded": [observ],
261+
"image/format": ["png"],
262+
"image/height": [self.frame_height],
263+
"image/width": [self.frame_width],
264+
"action": [int(action)],
265+
"done": [int(False)],
266+
"reward": [int(reward) - self.min_reward]}
267+
pieces_generated += 1
268+
observ, reward = next_observ, next_reward
267269

268270

269271
@registry.register_problem
@@ -273,7 +275,7 @@ class GymSimulatedDiscreteProblemWithAgent(GymDiscreteProblemWithAgent):
273275
def __init__(self, *args, **kwargs):
274276
super(GymSimulatedDiscreteProblemWithAgent, self).__init__(*args, **kwargs)
275277
self.simulated_environment = True
276-
self.debug_dump_frames_path = "/tmp/t2t_debug_dump_frames"
278+
self.debug_dump_frames_path = "debug_frames_sim"
277279

278280
def restore_networks(self, sess):
279281
super(GymSimulatedDiscreteProblemWithAgent, self).restore_networks(sess)

tensor2tensor/data_generators/imagenet.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def preprocess_example(self, example, mode, _):
189189

190190
@registry.register_problem
191191
class ImageImagenet64Gen(ImageImagenet):
192-
"""Cifar-10 Tune."""
192+
"""Imagenet 64 from the pixen cnn paper"""
193193

194194
@property
195195
def train_shards(self):
@@ -264,6 +264,33 @@ def preprocess_example(self, example, mode, hparams):
264264
return example
265265

266266

267+
@registry.register_problem
268+
class ImageImagenet32Small(ImageImagenet):
269+
"""Imagenet small from the pixel cnn paper"""
270+
271+
@property
272+
def is_small(self):
273+
return False # Modalities like for CIFAR.
274+
275+
@property
276+
def num_classes(self):
277+
return 1000
278+
279+
@property
280+
def train_shards(self):
281+
return 1024
282+
283+
@property
284+
def dev_shards(self):
285+
return 10
286+
287+
def preprocess_example(self, example, mode, unused_hparams):
288+
example["inputs"].set_shape([_IMAGENET_SMALL_IMAGE_SIZE,
289+
_IMAGENET_SMALL_IMAGE_SIZE, 3])
290+
example["inputs"] = tf.to_int64(example["inputs"])
291+
return example
292+
293+
267294
@registry.register_problem
268295
class ImageImagenet64(ImageImagenet32):
269296
"""Imagenet rescaled to 64x64."""

tensor2tensor/data_generators/squad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,5 +143,5 @@ def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
143143
for sample in samples:
144144
sample['targets'] = self.generate_targets(sample['targets'],
145145
sample['context'])
146-
if not sample['targets']:
146+
if sample['targets']:
147147
yield sample

0 commit comments

Comments (0)