@@ -245,7 +245,7 @@ def construct_model(self, images, actions, rewards):
245245
246246 Args:
247247 images: tensor of ground truth image sequences
248- actions: NOT used list of action tensors
248+ actions: list of action tensors
249249 rewards: NOT used list of reward tensors
250250
251251 Returns:
@@ -256,7 +256,19 @@ def construct_model(self, images, actions, rewards):
256256 """
257257 # model does not support action conditioned and reward prediction
258258 fake_reward_prediction = rewards
259- del actions , rewards
259+ del rewards
260+ action_repeat = self .hparams .action_repeat
261+ action_type = self .hparams .action_type
262+
263+ assert action_type in ["" , "image" , "vector" ], "Invalid action type."
264+ if not action_type :
265+ a_dim = 0
266+ elif action_type == "image" :
267+ a_dim = self .hparams .g_dim
268+ else :
269+ assert action_repeat > 0 , "Action repeat has to be positive integer."
270+ actions = tf .tile (actions , (1 , 1 , action_repeat ))
271+ a_dim = actions .shape [- 1 ]
260272
261273 z_dim = self .hparams .z_dim
262274 g_dim = self .hparams .g_dim
@@ -277,13 +289,20 @@ def construct_model(self, images, actions, rewards):
277289 tf .logging .info (">>>> Encoding" )
278290 # Encoding:
279291 enc_images , enc_skips = [], []
292+ enc_actions = []
280293 images = tf .unstack (images , axis = 0 )
294+ actions = tf .unstack (actions , axis = 0 )
281295 for i , image in enumerate (images ):
282296 with tf .variable_scope ("encoder" , reuse = tf .AUTO_REUSE ):
283297 enc , skips = self .encoder (image , g_dim , has_batchnorm = has_batchnorm )
284298 enc = tfl .flatten (enc )
285299 enc_images .append (enc )
286300 enc_skips .append (skips )
301+ if action_type == "image" :
302+ enc_action , _ = self .encoder (
303+ actions [i ], g_dim , has_batchnorm = has_batchnorm )
304+ enc_action = tfl .flatten (enc_action )
305+ enc_actions .append (enc_action )
287306
288307 tf .logging .info (">>>> Prediction" )
289308 # Prediction
@@ -304,6 +323,13 @@ def construct_model(self, images, actions, rewards):
304323 # target encoding
305324 h_target = enc_images [i ]
306325
326+ if action_type == "image" :
327+ h_current = tf .concat ([h_current , enc_actions [i - 1 ]], axis = 1 )
328+ h_target = tf .concat ([h_target , enc_actions [i ]], axis = 1 )
329+ elif action_type == "vector" :
330+ h_current = tf .concat ([h_current , actions [i - 1 ]], axis = 1 )
331+ h_target = tf .concat ([h_target , actions [i ]], axis = 1 )
332+
307333 with tf .variable_scope ("prediction" , reuse = tf .AUTO_REUSE ):
308334 # Prior parameters
309335 if self .hparams .learned_prior :
@@ -315,7 +341,8 @@ def construct_model(self, images, actions, rewards):
315341 logvar_prior = tf .zeros ((batch_size , z_dim ))
316342
317343 # Only use Posterior if it's training time
318- if self .is_training or len (gen_images ) < context_frames :
344+ if self .hparams .stochastic_model and \
345+ (self .is_training or len (gen_images ) < context_frames ):
319346 mu_pos , logvar_pos , posterior_states = self .lstm_gaussian (
320347 h_target , posterior_states , rnn_size , z_dim , posterior_rnn_layers ,
321348 "posterior" )
@@ -338,7 +365,11 @@ def construct_model(self, images, actions, rewards):
338365
339366 with tf .variable_scope ("decoding" , reuse = tf .AUTO_REUSE ):
340367 skip_index = min (context_frames - 1 , i - 1 )
341- h_pred = tf .reshape (h_pred , [batch_size , 1 , 1 , g_dim ])
368+ if action_type == "vector" :
369+ h_pred = tf .concat ([h_pred , actions [i - 1 ]], axis = - 1 )
370+ elif action_type == "image" :
371+ h_pred = tf .concat ([h_pred , enc_actions [i - 1 ]], axis = - 1 )
372+ h_pred = tf .reshape (h_pred , [batch_size , 1 , 1 , g_dim + a_dim ])
342373 if self .hparams .has_skips :
343374 x_pred = self .decoder (
344375 h_pred , color_channels ,
@@ -373,22 +404,37 @@ def body(self, features):
373404 input_frames = common_video .swap_time_and_batch_axes (features ["inputs" ])
374405 target_frames = common_video .swap_time_and_batch_axes (features ["targets" ])
375406
376- # Get actions if exist otherwise use zeros
377- input_actions = self .get_input_if_exists (
378- features , "input_action" , batch_size , hparams .video_num_input_frames )
379- target_actions = self .get_input_if_exists (
380- features , "target_action" , batch_size , hparams .video_num_target_frames )
381-
382407 # Get rewards if exist otherwise use zeros
383408 input_rewards = self .get_input_if_exists (
384409 features , "input_reward" , batch_size , hparams .video_num_input_frames )
385410 target_rewards = self .get_input_if_exists (
386411 features , "target_reward" , batch_size , hparams .video_num_target_frames )
387412
388- all_actions = tf .concat ([input_actions , target_actions ], axis = 0 )
389413 all_rewards = tf .concat ([input_rewards , target_rewards ], axis = 0 )
390414 all_frames = tf .concat ([input_frames , target_frames ], axis = 0 )
391415
416+ # Get actions if exist otherwise use zeros
417+ visualization_kwargs = {}
418+ if hparams .action_type == "image" :
419+ input_actions = common_video .swap_time_and_batch_axes (
420+ features ["input_action" ])
421+ target_actions = common_video .swap_time_and_batch_axes (
422+ features ["target_action" ])
423+ all_actions = tf .concat ([input_actions , target_actions ], axis = 0 )
424+ time , _ , h , w , c = all_frames .shape
425+ all_actions = tf .reshape (all_actions , (time , - 1 , h , w , c ))
426+ if self .hparams .action_normalize :
427+ all_actions /= 255.
428+ visualization_kwargs ["actions" ] = all_actions [:- 1 ]
429+ else :
430+ input_actions = self .get_input_if_exists (features , "input_action" ,
431+ batch_size ,
432+ hparams .video_num_input_frames )
433+ target_actions = self .get_input_if_exists (features , "target_action" ,
434+ batch_size ,
435+ hparams .video_num_target_frames )
436+ all_actions = tf .concat ([input_actions , target_actions ], axis = 0 )
437+
392438 # Each image is being used twice, in latent tower and main tower.
393439 # This is to make sure we are using the *same* image for both, ...
394440 # ... given how TF queues work.
@@ -414,7 +460,8 @@ def body(self, features):
414460
415461 # Visualize predictions in Tensorboard
416462 if self .is_training :
417- self .visualize_predictions (all_frames [1 :], gen_images )
463+ self .visualize_predictions (all_frames [1 :], gen_images ,
464+ ** visualization_kwargs )
418465
419466 # Ignore the predictions from the input frames.
420467 # This is NOT the same as original paper/implementation.
@@ -473,4 +520,8 @@ def next_frame_emily():
473520 hparams .add_hparam ("predictor_rnn_layers" , 2 )
474521 hparams .add_hparam ("has_skips" , True )
475522 hparams .add_hparam ("has_batchnorm" , True )
523+ # Repeat actions to signify gradients.
524+ # Action type can be '', 'image' or 'vector'.
525+ hparams .add_hparam ("action_repeat" , 40 )
526+ hparams .add_hparam ("action_type" , "" )
476527 return hparams
0 commit comments