@@ -40,26 +40,33 @@ def body(self, features):
4040 # Concat frames and down-stride.
4141 cur_frame = tf .to_float (features ["inputs" ])
4242 prev_frame = tf .to_float (features ["inputs_prev" ])
43- frames = tf .concat ([cur_frame , prev_frame ], axis = - 1 )
44- x = tf .layers .conv2d (frames , filters , kernel2 , activation = tf .nn .relu ,
45- strides = (2 , 2 ), padding = "SAME" )
43+ x = tf .concat ([cur_frame , prev_frame ], axis = - 1 )
44+ for _ in xrange (hparams .num_compress_steps ):
45+ x = tf .layers .conv2d (x , filters , kernel2 , activation = common_layers .belu ,
46+ strides = (2 , 2 ), padding = "SAME" )
47+ x = common_layers .layer_norm (x )
48+ filters *= 2
4649 # Add embedded action.
47- action = tf .reshape (features ["action" ], [- 1 , 1 , 1 , filters ])
48- x = tf .concat ([x , action + tf .zeros_like (x )], axis = - 1 )
50+ action = tf .reshape (features ["action" ], [- 1 , 1 , 1 , hparams .hidden_size ])
51+ zeros = tf .zeros (common_layers .shape_list (x )[:- 1 ] + [hparams .hidden_size ])
52+ x = tf .concat ([x , action + zeros ], axis = - 1 )
4953
5054 # Run a stack of convolutions.
5155 for i in xrange (hparams .num_hidden_layers ):
5256 with tf .variable_scope ("layer%d" % i ):
53- y = tf .layers .conv2d (x , 2 * filters , kernel1 , activation = tf . nn . relu ,
57+ y = tf .layers .conv2d (x , filters , kernel1 , activation = common_layers . belu ,
5458 strides = (1 , 1 ), padding = "SAME" )
5559 if i == 0 :
5660 x = y
5761 else :
5862 x = common_layers .layer_norm (x + y )
5963 # Up-convolve.
60- x = tf .layers .conv2d_transpose (
61- x , filters , kernel2 , activation = tf .nn .relu ,
62- strides = (2 , 2 ), padding = "SAME" )
64+ for _ in xrange (hparams .num_compress_steps ):
65+ filters //= 2
66+ x = tf .layers .conv2d_transpose (
67+ x , filters , kernel2 , activation = common_layers .belu ,
68+ strides = (2 , 2 ), padding = "SAME" )
69+ x = common_layers .layer_norm (x )
6370
6471 # Reward prediction.
6572 reward_pred_h1 = tf .reduce_mean (x , axis = [1 , 2 ], keep_dims = True )
@@ -78,7 +85,7 @@ def basic_conv():
7885 hparams = common_hparams .basic_params1 ()
7986 hparams .hidden_size = 64
8087 hparams .batch_size = 8
81- hparams .num_hidden_layers = 2
88+ hparams .num_hidden_layers = 3
8289 hparams .optimizer = "Adam"
8390 hparams .learning_rate_constant = 0.0002
8491 hparams .learning_rate_warmup_steps = 500
@@ -87,6 +94,7 @@ def basic_conv():
8794 hparams .initializer = "uniform_unit_scaling"
8895 hparams .initializer_gain = 1.0
8996 hparams .weight_decay = 0.0
97+ hparams .add_hparam ("num_compress_steps" , 2 )
9098 return hparams
9199
92100
0 commit comments