This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit f44f51f

Add shake-shake for CIFAR-10
1 parent 28e0e4e commit f44f51f

4 files changed, +111 -0 lines changed

tensor2tensor/models/common_hparams.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ def basic_params1():
       weight_noise=0.0,
       learning_rate_decay_scheme="none",
       learning_rate_warmup_steps=100,
+      learning_rate_cosine_cycle_steps=250000,
       learning_rate=0.1,
       sampling_method="argmax",  # "argmax" or "random"
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"

tensor2tensor/models/models.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@
 from tensor2tensor.models import modalities
 from tensor2tensor.models import multimodel
 from tensor2tensor.models import neural_gpu
+from tensor2tensor.models import shake_shake
 from tensor2tensor.models import slicenet
 from tensor2tensor.models import transformer
 from tensor2tensor.models import transformer_alternative
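
Note: the only functional change above is the import. Loading tensor2tensor.models.shake_shake is what runs the @registry.register_model and @registry.register_hparams decorators in the new file, so the model and its hparams set become selectable by name. A rough usage sketch, assuming the usual t2t-trainer flag names and the standard CIFAR-10 problem name, neither of which this commit itself confirms:

  t2t-trainer \
    --model=shake_shake \
    --hparams_set=shakeshake_cifar10 \
    --problems=image_cifar10 \
    --data_dir=... \
    --output_dir=...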

tensor2tensor/models/shake_shake.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensor2tensor.models import common_hparams
+from tensor2tensor.models import common_layers
+from tensor2tensor.utils import registry
+from tensor2tensor.utils import t2t_model
+
+import tensorflow as tf
+
+def shake_shake_block_branch(x, conv_filters, stride):
+  x = tf.nn.relu(x)
+  x = common_layers.conv(x, conv_filters, (3, 3), (stride, stride))
+  x = tf.layers.batch_normalization(x)
+  x = tf.nn.relu(x)
+  x = common_layers.conv(x, conv_filters, (3, 3), (1, 1))
+  x = tf.layers.batch_normalization(x)
+  return x
+
+
+def downsampling_residual_branch(x, conv_filters):
+  x = tf.nn.relu(x)
+
+  x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2))
+  x1 = common_layers.conv(x1, conv_filters // 2, (1, 1))
+
+  x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]])
+  x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2))
+  x2 = common_layers.conv(x2, conv_filters // 2, (1, 1))
+
+  return tf.concat([x1, x2], axis=3)
+
+
+def shake_shake_block(x, conv_filters, stride):
+  branch1 = shake_shake_block_branch(x, conv_filters, stride)
+  branch2 = shake_shake_block_branch(x, conv_filters, stride)
+  if x.shape[-1] == conv_filters:
+    skip = tf.identity(x)
+  else:
+    skip = downsampling_residual_branch(x, conv_filters)
+
+  # TODO(rshin): Set equal=true when testing.
+  # TODO(rshin): Use different alpha for each image in batch.
+  return skip + common_layers.shakeshake2(branch1, branch2)
+
+
+def shake_shake_stage(x, num_blocks, conv_filters, initial_stride):
+  x = shake_shake_block(x, conv_filters, initial_stride)
+  for _ in xrange(num_blocks - 1):
+    x = shake_shake_block(x, conv_filters, 1)
+  return x
+
+
+@registry.register_model
+class ShakeShake(t2t_model.T2TModel):
+
+  def model_fn_body(self, features):
+    hparams = self._hparams
+
+    inputs = features["inputs"]
+    assert (hparams.num_hidden_layers - 2) % 6 == 0
+    blocks_per_stage = (hparams.num_hidden_layers - 2) // 6
+
+    # For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16
+    # filters then a batch norm. Instead we use the one in SmallImageModality,
+    # which also seems to include a layer norm.
+    x = inputs
+    with tf.name_scope('shake_shake_stage_1'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1)
+    with tf.name_scope('shake_shake_stage_2'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2)
+    with tf.name_scope('shake_shake_stage_3'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2)
+
+    # For canonical Shake-Shake, we should perform 8x8 average pooling and then
+    # have a fully-connected layer (which produces the logits for each class).
+    # Instead, we just use the Xception exit flow in ClassLabelModality.
+    return x
+
+@registry.register_hparams
+def shakeshake_cifar10():
+  hparams = common_hparams.basic_params1()
+  # This leads to effective batch size 128 when number of GPUs is 2
+  hparams.batch_size = 4096 * 4
+  hparams.hidden_size = 16
+  hparams.dropout = 0
+  hparams.label_smoothing = 0.0
+  hparams.clip_grad_norm = 2.0
+  hparams.num_hidden_layers = 26
+  hparams.kernel_height = -1  # Unused
+  hparams.kernel_width = -1  # Unused
+  hparams.learning_rate_decay_scheme = "cosine"
+  # Model should be run for 700000 steps with batch size 128 (~1800 epochs)
+  hparams.learning_rate_cosine_cycle_steps = 700000
+  hparams.learning_rate = 0.2
+  hparams.learning_rate_warmup_steps = 3000
+  hparams.initializer = "uniform_unit_scaling"
+  hparams.initializer_gain = 1.0
+  hparams.weight_decay = 0.1  # Effective value should be ~1e-4
+  hparams.optimizer = "Momentum"
+  hparams.optimizer_momentum_momentum = 0.9
+  hparams.add_hparam('base_filters', 16)
+  return hparams
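
The Shake-Shake regularization itself lives in common_layers.shakeshake2; this file only wires it into a residual block. As a rough sketch of the idea (not the library's exact implementation): during training the two branch outputs are mixed with a random convex weight on the forward pass and an independently drawn weight on the backward pass, and at evaluation time both branches are weighted equally, which is what the "Set equal=true when testing" TODO refers to. A minimal illustration with scalar weights shared across the whole batch (the second TODO asks for a separate alpha per image), assuming the tensorflow module imported above as tf:

  def shake_shake_combine_sketch(branch1, branch2, is_training=True):
    # Illustrative only; uses TF 1.x-style tf.random_uniform.
    if not is_training:
      return 0.5 * (branch1 + branch2)  # Equal weighting at eval time.
    alpha = tf.random_uniform([])  # Mixing weight used in the forward pass.
    beta = tf.random_uniform([])   # Independent weight used for gradients.
    forward = alpha * branch1 + (1.0 - alpha) * branch2
    backward = beta * branch1 + (1.0 - beta) * branch2
    # Value equals `forward`; gradients flow only through `backward`.
    return backward + tf.stop_gradient(forward - backward)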

tensor2tensor/utils/trainer_utils.py

Lines changed: 3 additions & 0 deletions
@@ -321,6 +321,9 @@ def learning_rate_decay():
         (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
   elif hparams.learning_rate_decay_scheme == "exp100k":
     return 0.94**(step // 100000)
+  elif hparams.learning_rate_decay_scheme == "cosine":
+    cycle_steps = hparams.learning_rate_cosine_cycle_steps
+    return 0.5 * (1 + tf.cos(np.pi * (step % cycle_steps) / cycle_steps))
 
   inv_base = tf.exp(tf.log(0.01) / warmup_steps)
   inv_decay = inv_base**(warmup_steps - step)
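
The cosine branch returns a decay factor in [0, 1]: 1.0 at the start of a cycle, 0.5 halfway through, approaching 0 as step nears cycle_steps, then restarting. A quick numeric check of the same expression in plain NumPy, using the 700000-step cycle from shakeshake_cifar10:

  import numpy as np

  def cosine_multiplier(step, cycle_steps=700000):
    # Same formula as the TensorFlow branch above, evaluated eagerly.
    return 0.5 * (1 + np.cos(np.pi * (step % cycle_steps) / cycle_steps))

  print(cosine_multiplier(0))       # 1.0
  print(cosine_multiplier(350000))  # 0.5 (halfway through the cycle)
  print(cosine_multiplier(699999))  # ~0.0, after which the cycle restarts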
