
Commit c2ce7a6

Lukasz Kaiser authored and Ryan Sepassi committed
Play with CIFAR models and shake-shake a little.
PiperOrigin-RevId: 160016542
1 parent 9594213 commit c2ce7a6

8 files changed: +271 additions, −5 deletions

tensor2tensor/data_generators/image.py

Lines changed: 0 additions & 4 deletions
@@ -200,10 +200,6 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):
     ])
     labels = data["labels"]
     all_labels.extend([labels[j] for j in xrange(num_images)])
-  # Shuffle the data to make sure classes are well distributed.
-  data = zip(all_images, all_labels)
-  random.shuffle(data)
-  all_images, all_labels = zip(*data)
   return image_generator(all_images[start_from:start_from + how_many],
                          all_labels[start_from:start_from + how_many])

tensor2tensor/models/bluenet.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BlueNet: an out-of-the-blue network to experiment with shake-shake."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensor2tensor.models import common_hparams
+from tensor2tensor.models import common_layers
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+
+def residual_module(x, hparams, train, n, sep):
+  """A stack of convolution blocks with residual connection."""
+  k = (hparams.kernel_height, hparams.kernel_width)
+  dilations_and_kernels = [((1, 1), k) for _ in xrange(n)]
+  with tf.variable_scope("residual_module%d_sep%d" % (n, sep)):
+    y = common_layers.subseparable_conv_block(
+        x,
+        hparams.hidden_size,
+        dilations_and_kernels,
+        padding="SAME",
+        separability=sep,
+        name="block")
+    x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
+  return tf.nn.dropout(x, 1.0 - hparams.dropout * tf.to_float(train))
+
+
+def residual_module1(x, hparams, train):
+  return residual_module(x, hparams, train, 1, 1)
+
+
+def residual_module1_sep(x, hparams, train):
+  return residual_module(x, hparams, train, 1, 0)
+
+
+def residual_module2(x, hparams, train):
+  return residual_module(x, hparams, train, 2, 1)
+
+
+def residual_module2_sep(x, hparams, train):
+  return residual_module(x, hparams, train, 2, 0)
+
+
+def residual_module3(x, hparams, train):
+  return residual_module(x, hparams, train, 3, 1)
+
+
+def residual_module3_sep(x, hparams, train):
+  return residual_module(x, hparams, train, 3, 0)
+
+
+def norm_module(x, hparams, train):
+  del train  # Unused.
+  return common_layers.layer_norm(x, hparams.hidden_size, name="norm_module")
+
+
+def identity_module(x, hparams, train):
+  del hparams, train  # Unused.
+  return x
+
+
+def run_modules(blocks, cur, hparams, train, dp):
+  """Run blocks in parallel using dp as data_parallelism."""
+  assert len(blocks) % dp.n == 0
+  res = []
+  for i in xrange(len(blocks) // dp.n):
+    res.extend(dp(blocks[i * dp.n:(i + 1) * dp.n], cur, hparams, train))
+  return res
+
+
+@registry.register_model
+class BlueNet(t2t_model.T2TModel):
+
+  def model_fn_body_sharded(self, sharded_features, train):
+    dp = self._data_parallelism
+    dp._reuse = False  # pylint:disable=protected-access
+    hparams = self._hparams
+    blocks = [identity_module, norm_module,
+              residual_module1, residual_module1_sep,
+              residual_module2, residual_module2_sep,
+              residual_module3, residual_module3_sep]
+    inputs = sharded_features["inputs"]
+
+    cur = tf.concat(inputs, axis=0)
+    cur_shape = cur.get_shape()
+    for i in xrange(hparams.num_hidden_layers):
+      with tf.variable_scope("layer_%d" % i):
+        processed = run_modules(blocks, cur, hparams, train, dp)
+        cur = common_layers.shakeshake(processed)
+        cur.set_shape(cur_shape)
+
+    return list(tf.split(cur, len(inputs), axis=0)), 0.0
+
+
+@registry.register_hparams
+def bluenet_base():
+  """Set of hyperparameters."""
+  hparams = common_hparams.basic_params1()
+  hparams.batch_size = 4096
+  hparams.hidden_size = 768
+  hparams.dropout = 0.2
+  hparams.symbol_dropout = 0.2
+  hparams.label_smoothing = 0.1
+  hparams.clip_grad_norm = 2.0
+  hparams.num_hidden_layers = 8
+  hparams.kernel_height = 3
+  hparams.kernel_width = 3
+  hparams.learning_rate_decay_scheme = "exp50k"
+  hparams.learning_rate = 0.05
+  hparams.learning_rate_warmup_steps = 3000
+  hparams.initializer_gain = 1.0
+  hparams.weight_decay = 3.0
+  hparams.num_sampled_classes = 0
+  hparams.sampling_method = "argmax"
+  hparams.optimizer_adam_epsilon = 1e-6
+  hparams.optimizer_adam_beta1 = 0.85
+  hparams.optimizer_adam_beta2 = 0.997
+  hparams.add_hparam("imagenet_use_2d", True)
+  return hparams


+@registry.register_hparams
+def bluenet_tiny():
+  hparams = bluenet_base()
+  hparams.batch_size = 1024
+  hparams.hidden_size = 128
+  hparams.num_hidden_layers = 4
+  hparams.learning_rate_decay_scheme = "none"
+  return hparams
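
For intuition about the model body above: it concatenates the input shards, then at each layer runs all eight candidate modules on the same activations and blends the results with common_layers.shakeshake (added by this commit; see the common_layers.py hunk below). A minimal self-contained sketch of that pattern in TF 1.x follows. It is illustrative only: the toy_* names are not part of tensor2tensor, and it uses a single direct convex combination instead of shakeshake's recursive pairwise blend and custom gradient.

import tensorflow as tf

def toy_shakeshake(xs):
  # Blend the candidate outputs with a random convex combination.
  # (The real common_layers.shakeshake also resamples the weights when
  # computing gradients; this sketch covers the forward pass only.)
  alphas = tf.random_uniform([len(xs)])
  alphas = alphas / tf.reduce_sum(alphas)
  return tf.add_n([a * x for a, x in zip(tf.unstack(alphas), xs)])

def toy_layer(x):
  # Candidate "modules" applied to the same input: identity plus two
  # cheap transforms, standing in for BlueNet's identity/norm/residual set.
  return toy_shakeshake([x, tf.nn.relu(x), tf.tanh(x)])

x = tf.random_normal([4, 8])
y = toy_layer(x)  # same shape as x, so such layers can be stacked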
tensor2tensor/models/bluenet_test.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BlueNet tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+import numpy as np
+
+from tensor2tensor.data_generators import problem_hparams
+from tensor2tensor.models import bluenet
+
+import tensorflow as tf
+
+
+class BlueNetTest(tf.test.TestCase):
+
+  def testBlueNet(self):
+    vocab_size = 9
+    x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1))
+    y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1))
+    hparams = bluenet.bluenet_tiny()
+    p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
+                                                     vocab_size)
+    with self.test_session() as session:
+      features = {
+          "inputs": tf.constant(x, dtype=tf.int32),
+          "targets": tf.constant(y, dtype=tf.int32),
+      }
+      model = bluenet.BlueNet(hparams, p_hparams)
+      sharded_logits, _, _ = model.model_fn(features, True)
+      logits = tf.concat(sharded_logits, 0)
+      session.run(tf.global_variables_initializer())
+      res = session.run(logits)
+    self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))
+
+
+if __name__ == "__main__":
+  tf.test.main()
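
Shape check for the final assertion: the inputs have shape (3, 5, 1, 1) = (batch, length, 1, 1), and model_fn's output modality appends a vocab_size-way logit per position, hence (3, 5, 1, 1, 9).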

tensor2tensor/models/common_layers.py

Lines changed: 46 additions & 0 deletions
@@ -58,6 +58,52 @@ def inverse_exp_decay(max_step, min_value=0.01):
   return inv_base**tf.maximum(float(max_step) - step, 0.0)
 
 
+def shakeshake2_py(x, y, equal=False):
+  """The shake-shake sum of 2 tensors, python version."""
+  alpha = 0.5 if equal else tf.random_uniform([])
+  return alpha * x + (1.0 - alpha) * y
+
+
+@function.Defun()
+def shakeshake2_grad(x1, x2, dy):
+  """Overriding gradient for shake-shake of 2 tensors."""
+  y = shakeshake2_py(x1, x2)
+  dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
+  return dx
+
+
+@function.Defun()
+def shakeshake2_equal_grad(x1, x2, dy):
+  """Overriding gradient for shake-shake of 2 tensors."""
+  y = shakeshake2_py(x1, x2, equal=True)
+  dx = tf.gradients(ys=[y], xs=[x1, x2], grad_ys=[dy])
+  return dx
+
+
+@function.Defun(grad_func=shakeshake2_grad)
+def shakeshake2(x1, x2):
+  """The shake-shake function with a different alpha for forward/backward."""
+  return shakeshake2_py(x1, x2)
+
+
+@function.Defun(grad_func=shakeshake2_equal_grad)
+def shakeshake2_eqgrad(x1, x2):
+  """The shake-shake function with a different alpha for forward/backward."""
+  return shakeshake2_py(x1, x2)
+
+
+def shakeshake(xs, equal_grad=False):
+  """Multi-argument shake-shake, currently approximated by sums of 2."""
+  if len(xs) == 1:
+    return xs[0]
+  div = (len(xs) + 1) // 2
+  arg1 = shakeshake(xs[:div], equal_grad=equal_grad)
+  arg2 = shakeshake(xs[div:], equal_grad=equal_grad)
+  if equal_grad:
+    return shakeshake2_eqgrad(arg1, arg2)
+  return shakeshake2(arg1, arg2)
+
+
 def standardize_images(x):
   """Image standardization on batches (tf.image.per_image_standardization)."""
   with tf.name_scope("standardize_images", [x]):
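
As a sanity sketch of what these functions compute on the forward pass (a NumPy stand-in, not part of the library): a shake-shake sum is a random convex combination, applied pairwise when there are more than two inputs. The Defun gradient overrides are the interesting part: shakeshake2_grad recomputes the blend with a freshly sampled alpha, so the backward pass mixes gradients with different weights than the forward pass used (the "shake" in shake-shake), while the _eqgrad variant backpropagates as if the weights had been equal (alpha = 0.5).

import numpy as np

def shakeshake2_py(x, y, equal=False):
  # Random convex combination of two arrays (forward behavior only).
  alpha = 0.5 if equal else np.random.uniform()
  return alpha * x + (1.0 - alpha) * y

def shakeshake_py(xs):
  # Recursive pairwise reduction, mirroring common_layers.shakeshake.
  if len(xs) == 1:
    return xs[0]
  div = (len(xs) + 1) // 2
  return shakeshake2_py(shakeshake_py(xs[:div]), shakeshake_py(xs[div:]))

vals = [np.full(3, v) for v in (1.0, 2.0, 3.0)]
out = shakeshake_py(vals)  # every entry lies in [1.0, 3.0]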

tensor2tensor/models/common_layers_test.py

Lines changed: 9 additions & 0 deletions
@@ -65,6 +65,15 @@ def testEmbedding(self):
       res = session.run(y)
     self.assertEqual(res.shape, (3, 5, 16))
 
+  def testShakeShake(self):
+    x = np.random.rand(5, 7)
+    with self.test_session() as session:
+      x = tf.constant(x, dtype=tf.float32)
+      y = common_layers.shakeshake([x, x, x, x, x])
+      session.run(tf.global_variables_initializer())
+      inp, res = session.run([x, y])
+    self.assertAllClose(res, inp)
+
   def testConv(self):
     x = np.random.rand(5, 7, 1, 11)
     with self.test_session() as session:
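
This test passes for every draw of the random mixing weights: all five inputs are the same tensor, and each pairwise blend of equal inputs is alpha * x + (1 - alpha) * x = x, so the output equals the input exactly (up to float rounding, hence assertAllClose).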

tensor2tensor/models/models.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 
 from tensor2tensor.models import attention_lm
 from tensor2tensor.models import attention_lm_moe
+from tensor2tensor.models import bluenet
 from tensor2tensor.models import bytenet
 from tensor2tensor.models import lstm
 from tensor2tensor.models import modalities
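
The bare import is doing the work here: merely importing tensor2tensor.models.bluenet executes its @registry.register_model and @registry.register_hparams decorators, which is how BlueNet and its hparams sets become discoverable by name.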

tensor2tensor/models/xception.py

Lines changed: 10 additions & 0 deletions
@@ -87,3 +87,13 @@ def xception_base():
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
   return hparams
+
+
+@registry.register_hparams
+def xception_tiny():
+  hparams = xception_base()
+  hparams.batch_size = 1024
+  hparams.hidden_size = 128
+  hparams.num_hidden_layers = 4
+  hparams.learning_rate_decay_scheme = "none"
+  return hparams

tensor2tensor/models/xception_test.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def testXception(self):
     vocab_size = 9
     x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1))
     y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1))
-    hparams = xception.xception_base()
+    hparams = xception.xception_tiny()
     p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size,
                                                      vocab_size)
     with self.test_session() as session:
