This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 9f02a51

Fix bugs, add more explanation
1 parent f44f51f commit 9f02a51

2 files changed: 57 additions and 25 deletions

tensor2tensor/models/common_layers.py

Lines changed: 6 additions & 1 deletion
@@ -58,7 +58,7 @@ def inverse_exp_decay(max_step, min_value=0.01):
   return inv_base**tf.maximum(float(max_step) - step, 0.0)
 
 
-def shakeshake2_py(x, y, equal=False):
+def shakeshake2_py(x, y, equal=False, individual=False):
   """The shake-shake sum of 2 tensors, python version."""
   alpha = 0.5 if equal else tf.random_uniform([])
   return alpha * x + (1.0 - alpha) * y
@@ -85,6 +85,11 @@ def shakeshake2(x1, x2):
   """The shake-shake function with a different alpha for forward/backward."""
   return shakeshake2_py(x1, x2)
 
+@function.Defun(grad_func=shakeshake2_grad)
+def shakeshake2_eqforward(x1, x2):
+  """The shake-shake function with equal alpha on the forward pass."""
+  return shakeshake2_py(x1, x2, equal=True)
+
 
 @function.Defun(grad_func=shakeshake2_equal_grad)
 def shakeshake2_eqgrad(x1, x2):
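
For intuition about what shakeshake2 and the new shakeshake2_eqforward compute, here is a minimal NumPy sketch of the shake-shake combination: the forward pass takes a random convex combination of the two branch outputs, the backward pass mixes the incoming gradient with an independently drawn coefficient, and the "equal" variants pin the forward coefficient to 0.5. This is only an illustrative sketch of the intended semantics, not the Defun-based implementation above; the helper names are made up.

import numpy as np

def shake_forward(x, y, equal=False):
  # Forward pass: out = alpha * x + (1 - alpha) * y, with alpha ~ Uniform(0, 1).
  alpha = 0.5 if equal else np.random.uniform()
  return alpha * x + (1.0 - alpha) * y

def shake_backward(dout, equal=False):
  # Backward pass: the gradient is split with an independent beta,
  # not the alpha that was used in the forward pass.
  beta = 0.5 if equal else np.random.uniform()
  return beta * dout, (1.0 - beta) * dout  # (d_branch1, d_branch2)

# Training uses a random alpha forward and an independent beta backward;
# shakeshake2_eqforward corresponds to equal=True on the forward pass only,
# i.e. a plain 0.5/0.5 average of the two branches, as used at evaluation time.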

tensor2tensor/models/shake_shake.py

Lines changed: 51 additions & 24 deletions
@@ -11,12 +11,14 @@
 
 import tensorflow as tf
 
+
 def shake_shake_block_branch(x, conv_filters, stride):
   x = tf.nn.relu(x)
-  x = common_layers.conv(x, conv_filters, (3, 3), (stride, stride))
+  x = tf.layers.conv2d(
+      x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME')
   x = tf.layers.batch_normalization(x)
   x = tf.nn.relu(x)
-  x = common_layers.conv(x, conv_filters, (3, 3), (1, 1))
+  x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME')
   x = tf.layers.batch_normalization(x)
   return x
 
@@ -25,66 +27,90 @@ def downsampling_residual_branch(x, conv_filters):
   x = tf.nn.relu(x)
 
   x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2))
-  x1 = common_layers.conv(x1, conv_filters / 2, (1, 1))
+  x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding='SAME')
 
   x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]])
   x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2))
-  x2 = common_layers.conv(x2, conv_filters / 2, (1, 1))
+  x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding='SAME')
 
   return tf.concat([x1, x2], axis=3)
 
 
-def shake_shake_block(x, conv_filters, stride):
-  branch1 = shake_shake_block_branch(x, conv_filters, stride)
-  branch2 = shake_shake_block_branch(x, conv_filters, stride)
+def shake_shake_block(x, conv_filters, stride, mode):
+  with tf.variable_scope('branch_1'):
+    branch1 = shake_shake_block_branch(x, conv_filters, stride)
+  with tf.variable_scope('branch_2'):
+    branch2 = shake_shake_block_branch(x, conv_filters, stride)
   if x.shape[-1] == conv_filters:
     skip = tf.identity(x)
   else:
-    skip = downsampling_residual_block(x)
+    skip = downsampling_residual_branch(x, conv_filters)
 
-  # TODO(rshin): Set equal=true when testing.
   # TODO(rshin): Use different alpha for each image in batch.
-  return skip + common_layers.shakeshake2(branch1, branch2)
+  if mode == tf.contrib.learn.ModeKeys.TRAIN:
+    shaken = common_layers.shakeshake2(branch1, branch2)
+  else:
+    shaken = common_layers.shakeshake2_eqforward(branch1, branch2)
+  shaken.set_shape(branch1.get_shape())
+
+  return skip + shaken
 
 
-def shake_shake_stage(x, num_blocks, conv_filters, initial_stride):
-  x = shake_shake_block(x, conv_filters, initial_stride)
-  for _ in xrange(num_blocks - 1):
-    x = shake_shake_block(x, conv_filters, 1)
+def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, mode):
+  with tf.variable_scope('block_0'):
+    x = shake_shake_block(x, conv_filters, initial_stride, mode)
+  for i in xrange(1, num_blocks):
+    with tf.variable_scope('block_{}'.format(i)):
+      x = shake_shake_block(x, conv_filters, 1, mode)
   return x
 
 
 @registry.register_model
 class ShakeShake(t2t_model.T2TModel):
+  '''Implements the Shake-Shake architecture.
+
+  From <https://arxiv.org/pdf/1705.07485.pdf>
+  This is intended to match the CIFAR-10 version, and correspond to
+  "Shake-Shake-Batch" in Table 1.
+  '''
 
   def model_fn_body(self, features):
     hparams = self._hparams
 
     inputs = features["inputs"]
     assert (hparams.num_hidden_layers - 2) % 6 == 0
-    blocks_per_stage = (hparams.num_hidden_layers - 2) / 6
+    blocks_per_stage = (hparams.num_hidden_layers - 2) // 6
 
     # For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16
     # filters then a batch norm. Instead we use the one in SmallImageModality,
     # which also seems to include a layer norm.
     x = inputs
-    with tf.name_scope('shake_shake_stage_1'):
-      x = shake_shake_stage(x, hparams.base_filters, blocks_per_stage)
-    with tf.name_scope('shake_shake_stage_2'):
-      x = shake_shake_stage(x, hparams.base_filters * 2, blocks_per_stage)
-    with tf.name_scope('shake_shake_stage_3'):
-      x = shake_shake_stage(x, hparams.base_filters * 4, blocks_per_stage)
+    mode = hparams.mode
+    with tf.variable_scope('shake_shake_stage_1'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, mode)
+    with tf.variable_scope('shake_shake_stage_2'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2,
+                            mode)
+    with tf.variable_scope('shake_shake_stage_3'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2,
+                            mode)
 
     # For canonical Shake-Shake, we should perform 8x8 average pooling and then
     # have a fully-connected layer (which produces the logits for each class).
     # Instead, we just use the Xception exit flow in ClassLabelModality.
+    #
+    # Also, this model_fn does not return an extra_loss. However, TensorBoard
+    # reports an exponential moving average for extra_loss, where the initial
+    # value for the moving average may be a large number, so extra_loss will
+    # look large at the beginning of training.
     return x
 
+
 @registry.register_hparams
 def shakeshake_cifar10():
   hparams = common_hparams.basic_params1()
-  # This leads to effective batch size 128 when number of GPUs is 2
-  hparams.batch_size = 4096 * 4
+  # This leads to effective batch size 128 when number of GPUs is 1
+  hparams.batch_size = 4096 * 8
   hparams.hidden_size = 16
   hparams.dropout = 0
   hparams.label_smoothing = 0.0
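
For a feel of what downsampling_residual_branch does to shapes, here is a rough NumPy sketch under assumed shapes: average pooling with a (1, 1) window and stride 2 is just subsampling every other pixel, and random projections stand in for the two 1x1 convolutions. This is illustrative only, not the TensorFlow code in the diff; the function name and sizes are made up.

import numpy as np

def downsampling_shortcut_sketch(x, conv_filters):
  # x is NHWC; the output has half the spatial size and conv_filters channels.
  n, h, w, c = x.shape
  w1 = np.random.randn(c, conv_filters // 2)
  w2 = np.random.randn(c, conv_filters // 2)

  # Path 1: subsample every other pixel, project to conv_filters // 2 channels.
  x1 = np.matmul(x[:, ::2, ::2, :], w1)

  # Path 2: shift one pixel down/right (padding to keep the size), then the same.
  shifted = np.pad(x[:, 1:, 1:, :], [(0, 0), (0, 1), (0, 1), (0, 0)],
                   mode='constant')
  x2 = np.matmul(shifted[:, ::2, ::2, :], w2)

  # Concatenate along channels, mirroring tf.concat([x1, x2], axis=3).
  return np.concatenate([x1, x2], axis=-1)

print(downsampling_shortcut_sketch(np.zeros((2, 32, 32, 16)), 32).shape)
# -> (2, 16, 16, 32)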
@@ -99,7 +125,8 @@ def shakeshake_cifar10():
   hparams.learning_rate_warmup_steps = 3000
   hparams.initializer = "uniform_unit_scaling"
   hparams.initializer_gain = 1.0
-  hparams.weight_decay = 0.1  # Effective value should be ~1e-4
+  # TODO(rshin): Adjust so that effective value becomes ~1e-4
+  hparams.weight_decay = 3.0
   hparams.optimizer = "Momentum"
   hparams.optimizer_momentum_momentum = 0.9
   hparams.add_hparam('base_filters', 16)
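
As a worked example of the depth arithmetic behind blocks_per_stage in model_fn_body above, assuming num_hidden_layers = 26 (the depth of the standard CIFAR-10 Shake-Shake model; that value is not set anywhere in this diff):

num_hidden_layers = 26                           # assumed for illustration
assert (num_hidden_layers - 2) % 6 == 0
blocks_per_stage = (num_hidden_layers - 2) // 6  # 24 // 6 = 4
convs = 3 * blocks_per_stage * 2                 # 3 stages, 2 convs per block per branch = 24
print(blocks_per_stage, convs + 2)               # 4 blocks per stage; 24 convs + entry + exit = 26 layers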

0 commit comments
