
 import tensorflow as tf

+
 def shake_shake_block_branch(x, conv_filters, stride):
   x = tf.nn.relu(x)
-  x = common_layers.conv(x, conv_filters, (3, 3), (stride, stride))
+  x = tf.layers.conv2d(
+      x, conv_filters, (3, 3), strides=(stride, stride), padding='SAME')
   x = tf.layers.batch_normalization(x)
   x = tf.nn.relu(x)
-  x = common_layers.conv(x, conv_filters, (3, 3), (1, 1))
+  x = tf.layers.conv2d(x, conv_filters, (3, 3), strides=(1, 1), padding='SAME')
   x = tf.layers.batch_normalization(x)
   return x

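As a quick shape check (an illustrative sketch, not part of this commit; TF 1.x API): with padding='SAME', a branch changes the spatial size only through its stride, so a 32x32 input with stride 2 halves to 16x16 while the channel count becomes conv_filters.

# Illustrative only: build the branch on a placeholder and inspect its shape.
x = tf.placeholder(tf.float32, [None, 32, 32, 16])
y = shake_shake_block_branch(x, conv_filters=32, stride=2)
print(y.get_shape())  # (?, 16, 16, 32): spatial dims halved, channels doubled
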
@@ -25,66 +27,90 @@ def downsampling_residual_branch(x, conv_filters):
   x = tf.nn.relu(x)

   x1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2))
-  x1 = common_layers.conv(x1, conv_filters / 2, (1, 1))
+  x1 = tf.layers.conv2d(x1, conv_filters / 2, (1, 1), padding='SAME')

   x2 = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]])
   x2 = tf.layers.average_pooling2d(x2, pool_size=(1, 1), strides=(2, 2))
-  x2 = common_layers.conv(x2, conv_filters / 2, (1, 1))
+  x2 = tf.layers.conv2d(x2, conv_filters / 2, (1, 1), padding='SAME')

   return tf.concat([x1, x2], axis=3)

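The downsampling shortcut above combines two stride-2 average-pooling paths, the second applied to the input shifted by one pixel, so that together they sample the even and the odd grid positions before each goes through a 1x1 convolution and the results are concatenated along channels. A small numeric sketch of the sampling (illustrative only, not part of this commit):

# Illustrative only: a 4x4 single-channel input with entries 0..15 shows that
# the two pooling paths pick complementary pixels.
import numpy as np

x = tf.constant(np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1))
p1 = tf.layers.average_pooling2d(x, pool_size=(1, 1), strides=(2, 2))
x_shift = tf.pad(x[:, 1:, 1:], [[0, 0], [0, 1], [0, 1], [0, 0]])
p2 = tf.layers.average_pooling2d(x_shift, pool_size=(1, 1), strides=(2, 2))
with tf.Session() as sess:
  print(sess.run(p1)[0, :, :, 0])  # [[ 0.  2.] [ 8. 10.]] -- even rows/cols
  print(sess.run(p2)[0, :, :, 0])  # [[ 5.  7.] [13. 15.]] -- odd rows/cols
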

-def shake_shake_block(x, conv_filters, stride):
-  branch1 = shake_shake_block_branch(x, conv_filters, stride)
-  branch2 = shake_shake_block_branch(x, conv_filters, stride)
+def shake_shake_block(x, conv_filters, stride, mode):
+  with tf.variable_scope('branch_1'):
+    branch1 = shake_shake_block_branch(x, conv_filters, stride)
+  with tf.variable_scope('branch_2'):
+    branch2 = shake_shake_block_branch(x, conv_filters, stride)
   if x.shape[-1] == conv_filters:
     skip = tf.identity(x)
   else:
-    skip = downsampling_residual_block(x)
+    skip = downsampling_residual_branch(x, conv_filters)

-  # TODO(rshin): Set equal=true when testing.
   # TODO(rshin): Use different alpha for each image in batch.
-  return skip + common_layers.shakeshake2(branch1, branch2)
+  if mode == tf.contrib.learn.ModeKeys.TRAIN:
+    shaken = common_layers.shakeshake2(branch1, branch2)
+  else:
+    shaken = common_layers.shakeshake2_eqforward(branch1, branch2)
+  shaken.set_shape(branch1.get_shape())
+
+  return skip + shaken

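The mixing done by common_layers.shakeshake2 is the heart of the method: during training the two branches are combined with a random coefficient in the forward pass and a different random coefficient in the backward pass (one coefficient shared across the batch here, hence "Shake-Shake-Batch"; the per-image alpha in the TODO would correspond to the paper's "Image" level), while shakeshake2_eqforward presumably uses equal forward weights at eval time. A rough, self-contained sketch of that combination (not the actual common_layers implementation; the function name is made up):

def shakeshake_combine_sketch(x1, x2, is_training):
  # Illustrative only. Forward pass mixes with alpha; gradients flow as if the
  # mix had used beta. tf.stop_gradient decouples the two passes.
  if not is_training:
    return 0.5 * (x1 + x2)
  alpha = tf.random_uniform([])  # forward mixing coefficient in [0, 1)
  beta = tf.random_uniform([])   # backward mixing coefficient in [0, 1)
  forward = alpha * x1 + (1.0 - alpha) * x2
  backward = beta * x1 + (1.0 - beta) * x2
  # Value equals `forward`; gradients propagate only through `backward`.
  return backward + tf.stop_gradient(forward - backward)
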

-def shake_shake_stage(x, num_blocks, conv_filters, initial_stride):
-  x = shake_shake_block(x, conv_filters, initial_stride)
-  for _ in xrange(num_blocks - 1):
-    x = shake_shake_block(x, conv_filters, 1)
+def shake_shake_stage(x, num_blocks, conv_filters, initial_stride, mode):
+  with tf.variable_scope('block_0'):
+    x = shake_shake_block(x, conv_filters, initial_stride, mode)
+  for i in xrange(1, num_blocks):
+    with tf.variable_scope('block_{}'.format(i)):
+      x = shake_shake_block(x, conv_filters, 1, mode)
   return x


 @registry.register_model
 class ShakeShake(t2t_model.T2TModel):
+  '''Implements the Shake-Shake architecture.
+
+  From <https://arxiv.org/pdf/1705.07485.pdf>
+  This is intended to match the CIFAR-10 version, and correspond to
+  "Shake-Shake-Batch" in Table 1.
+  '''

   def model_fn_body(self, features):
     hparams = self._hparams

     inputs = features["inputs"]
     assert (hparams.num_hidden_layers - 2) % 6 == 0
-    blocks_per_stage = (hparams.num_hidden_layers - 2) / 6
+    blocks_per_stage = (hparams.num_hidden_layers - 2) // 6

     # For canonical Shake-Shake, the entry flow is a 3x3 convolution with 16
     # filters then a batch norm. Instead we use the one in SmallImageModality,
     # which also seems to include a layer norm.
     x = inputs
-    with tf.name_scope('shake_shake_stage_1'):
-      x = shake_shake_stage(x, hparams.base_filters, blocks_per_stage)
-    with tf.name_scope('shake_shake_stage_2'):
-      x = shake_shake_stage(x, hparams.base_filters * 2, blocks_per_stage)
-    with tf.name_scope('shake_shake_stage_3'):
-      x = shake_shake_stage(x, hparams.base_filters * 4, blocks_per_stage)
+    mode = hparams.mode
+    with tf.variable_scope('shake_shake_stage_1'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters, 1, mode)
+    with tf.variable_scope('shake_shake_stage_2'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 2, 2,
+                            mode)
+    with tf.variable_scope('shake_shake_stage_3'):
+      x = shake_shake_stage(x, blocks_per_stage, hparams.base_filters * 4, 2,
+                            mode)

     # For canonical Shake-Shake, we should perform 8x8 average pooling and then
     # have a fully-connected layer (which produces the logits for each class).
     # Instead, we just use the Xception exit flow in ClassLabelModality.
+    #
+    # Also, this model_fn does not return an extra_loss. However, TensorBoard
+    # reports an exponential moving average for extra_loss, where the initial
+    # value for the moving average may be a large number, so extra_loss will
+    # look large at the beginning of training.
     return x

+
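A note on the depth arithmetic in model_fn_body above: if hparams.num_hidden_layers were 26 (the depth of the paper's "26 2x32d" model), the assertion passes since (26 - 2) % 6 == 0, and blocks_per_stage = (26 - 2) // 6 = 4, i.e. each of the three stages stacks four shake-shake blocks.
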
 @registry.register_hparams
 def shakeshake_cifar10():
   hparams = common_hparams.basic_params1()
-  # This leads to effective batch size 128 when number of GPUs is 2
-  hparams.batch_size = 4096 * 4
+  # This leads to effective batch size 128 when number of GPUs is 1
+  hparams.batch_size = 4096 * 8
   hparams.hidden_size = 16
   hparams.dropout = 0
   hparams.label_smoothing = 0.0
@@ -99,7 +125,8 @@ def shakeshake_cifar10():
   hparams.learning_rate_warmup_steps = 3000
   hparams.initializer = "uniform_unit_scaling"
   hparams.initializer_gain = 1.0
-  hparams.weight_decay = 0.1  # Effective value should be ~1e-4
+  # TODO(rshin): Adjust so that effective value becomes ~1e-4
+  hparams.weight_decay = 3.0
   hparams.optimizer = "Momentum"
   hparams.optimizer_momentum_momentum = 0.9
   hparams.add_hparam('base_filters', 16)
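
Assuming the tensor2tensor registry derives snake_case names from the decorated definitions (an assumption; the registry code is not shown in this diff), the model above would be referenced as shake_shake and this hparams set as shakeshake_cifar10 when configuring a CIFAR-10 training run.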