
Commit b66e054

T2T Team authored and copybara-github committed

Allow action vectors and images as input for training Emily. When using images, action visualization is added to tf summary.

PiperOrigin-RevId: 316106772

1 parent fee90f8 commit b66e054
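
For orientation, a minimal hypothetical sketch of how the new action hparams introduced in this commit might be enabled. The hparam names (action_type, action_repeat, action_normalize) come from the diff below; the surrounding setup (importing and calling next_frame_emily() directly, assigning attributes on the returned HParams) is illustrative only and not part of the change.

# Hypothetical usage sketch, not part of this commit.
from tensor2tensor.models.video import emily

hparams = emily.next_frame_emily()   # hparams set extended at the bottom of this diff
hparams.action_type = "vector"       # "", "image" or "vector"
hparams.action_repeat = 40           # tile factor applied to vector actions
# For image actions, the SV2P hparams gain an optional pixel rescaling:
# hparams.action_normalize = True    # divide pixel-valued actions by 255.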

File tree

3 files changed: +75 −15 lines changed

tensor2tensor/models/video/emily.py

Lines changed: 63 additions & 12 deletions
@@ -245,7 +245,7 @@ def construct_model(self, images, actions, rewards):

     Args:
       images: tensor of ground truth image sequences
-      actions: NOT used list of action tensors
+      actions: list of action tensors
       rewards: NOT used list of reward tensors

     Returns:
@@ -256,7 +256,19 @@ def construct_model(self, images, actions, rewards):
     """
     # model does not support action conditioned and reward prediction
     fake_reward_prediction = rewards
-    del actions, rewards
+    del rewards
+    action_repeat = self.hparams.action_repeat
+    action_type = self.hparams.action_type
+
+    assert action_type in ["", "image", "vector"], "Invalid action type."
+    if not action_type:
+      a_dim = 0
+    elif action_type == "image":
+      a_dim = self.hparams.g_dim
+    else:
+      assert action_repeat > 0, "Action repeat has to be positive integer."
+      actions = tf.tile(actions, (1, 1, action_repeat))
+      a_dim = actions.shape[-1]

     z_dim = self.hparams.z_dim
     g_dim = self.hparams.g_dim
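
As a side note, a tiny standalone sketch (illustrative shapes, not part of the diff) of what the tf.tile call above does to a vector action: the last axis is repeated action_repeat times, and a_dim is then read off the tiled tensor.

import tensorflow as tf

# Illustrative shapes: (time, batch, action_dim) actions with action_repeat = 40.
actions = tf.zeros((10, 8, 4))
action_repeat = 40
actions = tf.tile(actions, (1, 1, action_repeat))  # last axis becomes 4 * 40 = 160
a_dim = actions.shape[-1]                           # 160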
@@ -277,13 +289,20 @@ def construct_model(self, images, actions, rewards):
     tf.logging.info(">>>> Encoding")
     # Encoding:
     enc_images, enc_skips = [], []
+    enc_actions = []
     images = tf.unstack(images, axis=0)
+    actions = tf.unstack(actions, axis=0)
     for i, image in enumerate(images):
       with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
         enc, skips = self.encoder(image, g_dim, has_batchnorm=has_batchnorm)
         enc = tfl.flatten(enc)
         enc_images.append(enc)
         enc_skips.append(skips)
+        if action_type == "image":
+          enc_action, _ = self.encoder(
+              actions[i], g_dim, has_batchnorm=has_batchnorm)
+          enc_action = tfl.flatten(enc_action)
+          enc_actions.append(enc_action)

     tf.logging.info(">>>> Prediction")
     # Prediction
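
To illustrate the image-action branch above, a toy standalone sketch: a stand-in convolutional encoder (not the model's self.encoder) that maps a frame or an action image to a flattened g_dim-wide encoding, which is what lets enc_actions be concatenated with the frame encodings later. All names and shapes here are illustrative.

import tensorflow as tf
tfl = tf.layers  # TF 1.x layers alias

def toy_encoder(image, g_dim):
  # Stand-in for self.encoder: one conv + global average pooling down to g_dim.
  net = tfl.conv2d(image, g_dim, 3, strides=2, padding="same",
                   activation=tf.nn.relu)
  return tf.reduce_mean(net, axis=[1, 2], keepdims=True)  # (batch, 1, 1, g_dim)

frame = tf.zeros((8, 64, 64, 3))
action_image = tf.zeros((8, 64, 64, 3))
enc = tfl.flatten(toy_encoder(frame, 128))                # (8, 128)
enc_action = tfl.flatten(toy_encoder(action_image, 128))  # (8, 128)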
@@ -304,6 +323,13 @@ def construct_model(self, images, actions, rewards):
       # target encoding
       h_target = enc_images[i]

+      if action_type == "image":
+        h_current = tf.concat([h_current, enc_actions[i - 1]], axis=1)
+        h_target = tf.concat([h_target, enc_actions[i]], axis=1)
+      elif action_type == "vector":
+        h_current = tf.concat([h_current, actions[i - 1]], axis=1)
+        h_target = tf.concat([h_target, actions[i]], axis=1)
+
       with tf.variable_scope("prediction", reuse=tf.AUTO_REUSE):
         # Prior parameters
         if self.hparams.learned_prior:
@@ -315,7 +341,8 @@ def construct_model(self, images, actions, rewards):
           logvar_prior = tf.zeros((batch_size, z_dim))

         # Only use Posterior if it's training time
-        if self.is_training or len(gen_images) < context_frames:
+        if self.hparams.stochastic_model and \
+            (self.is_training or len(gen_images) < context_frames):
           mu_pos, logvar_pos, posterior_states = self.lstm_gaussian(
               h_target, posterior_states, rnn_size, z_dim, posterior_rnn_layers,
               "posterior")
@@ -338,7 +365,11 @@ def construct_model(self, images, actions, rewards):

       with tf.variable_scope("decoding", reuse=tf.AUTO_REUSE):
         skip_index = min(context_frames-1, i-1)
-        h_pred = tf.reshape(h_pred, [batch_size, 1, 1, g_dim])
+        if action_type == "vector":
+          h_pred = tf.concat([h_pred, actions[i - 1]], axis=-1)
+        elif action_type == "image":
+          h_pred = tf.concat([h_pred, enc_actions[i - 1]], axis=-1)
+        h_pred = tf.reshape(h_pred, [batch_size, 1, 1, g_dim + a_dim])
         if self.hparams.has_skips:
           x_pred = self.decoder(
               h_pred, color_channels,
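
A small shape-level sketch (illustrative sizes, not part of the diff) of the decoder-input change above: after concatenating the action, the reshape target has to grow from g_dim to g_dim + a_dim channels.

import tensorflow as tf

batch_size, g_dim, a_dim = 8, 128, 160            # illustrative sizes
h_pred = tf.zeros((batch_size, g_dim))            # predictor output for frame i
action = tf.zeros((batch_size, a_dim))            # (tiled or encoded) action for frame i - 1
h_pred = tf.concat([h_pred, action], axis=-1)     # (8, 288)
h_pred = tf.reshape(h_pred, [batch_size, 1, 1, g_dim + a_dim])  # (8, 1, 1, 288)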
@@ -373,22 +404,37 @@ def body(self, features):
     input_frames = common_video.swap_time_and_batch_axes(features["inputs"])
     target_frames = common_video.swap_time_and_batch_axes(features["targets"])

-    # Get actions if exist otherwise use zeros
-    input_actions = self.get_input_if_exists(
-        features, "input_action", batch_size, hparams.video_num_input_frames)
-    target_actions = self.get_input_if_exists(
-        features, "target_action", batch_size, hparams.video_num_target_frames)
-
     # Get rewards if exist otherwise use zeros
     input_rewards = self.get_input_if_exists(
         features, "input_reward", batch_size, hparams.video_num_input_frames)
     target_rewards = self.get_input_if_exists(
         features, "target_reward", batch_size, hparams.video_num_target_frames)

-    all_actions = tf.concat([input_actions, target_actions], axis=0)
     all_rewards = tf.concat([input_rewards, target_rewards], axis=0)
     all_frames = tf.concat([input_frames, target_frames], axis=0)

+    # Get actions if exist otherwise use zeros
+    visualization_kwargs = {}
+    if hparams.action_type == "image":
+      input_actions = common_video.swap_time_and_batch_axes(
+          features["input_action"])
+      target_actions = common_video.swap_time_and_batch_axes(
+          features["target_action"])
+      all_actions = tf.concat([input_actions, target_actions], axis=0)
+      time, _, h, w, c = all_frames.shape
+      all_actions = tf.reshape(all_actions, (time, -1, h, w, c))
+      if self.hparams.action_normalize:
+        all_actions /= 255.
+      visualization_kwargs["actions"] = all_actions[:-1]
+    else:
+      input_actions = self.get_input_if_exists(features, "input_action",
+                                               batch_size,
+                                               hparams.video_num_input_frames)
+      target_actions = self.get_input_if_exists(features, "target_action",
+                                                batch_size,
+                                                hparams.video_num_target_frames)
+      all_actions = tf.concat([input_actions, target_actions], axis=0)
+
     # Each image is being used twice, in latent tower and main tower.
     # This is to make sure we are using the *same* image for both, ...
     # ... given how TF queues work.
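
For the image-action branch above, a standalone sketch with illustrative shapes; the assumption that action features arrive flattened per frame (hence the reshape back to (time, batch, h, w, c)) is mine, not stated in the diff.

import tensorflow as tf

time, batch, h, w, c = 11, 8, 64, 64, 3
# Assumed layout: one flattened action image per frame.
all_actions = tf.zeros((time, batch, h * w * c))
all_actions = tf.reshape(all_actions, (time, -1, h, w, c))  # (11, 8, 64, 64, 3)
all_actions = all_actions / 255.              # what action_normalize enables
actions_for_visualization = all_actions[:-1]  # what gets passed as "actions"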
@@ -414,7 +460,8 @@ def body(self, features):

     # Visualize predictions in Tensorboard
     if self.is_training:
-      self.visualize_predictions(all_frames[1:], gen_images)
+      self.visualize_predictions(all_frames[1:], gen_images,
+                                 **visualization_kwargs)

     # Ignore the predictions from the input frames.
     # This is NOT the same as original paper/implementation.
@@ -473,4 +520,8 @@ def next_frame_emily():
   hparams.add_hparam("predictor_rnn_layers", 2)
   hparams.add_hparam("has_skips", True)
   hparams.add_hparam("has_batchnorm", True)
+  # Repeat actions to signify gradients.
+  # Action type can be '', 'image' or 'vector'.
+  hparams.add_hparam("action_repeat", 40)
+  hparams.add_hparam("action_type", "")
   return hparams

tensor2tensor/models/video/sv2p.py

Lines changed: 11 additions & 3 deletions
@@ -509,14 +509,16 @@ def save_internal_states_ops(self, internal_states):
 class NextFrameSv2pLegacy(NextFrameSv2p):
   """Old SV2P code. Only for legacy reasons."""

-  def visualize_predictions(self, real_frames, gen_frames):
+  def visualize_predictions(self, real_frames, gen_frames, actions=None):
+
     def concat_on_y_axis(x):
       x = tf.unstack(x, axis=1)
       x = tf.concat(x, axis=1)
       return x
-
     frames_gd = common_video.swap_time_and_batch_axes(real_frames)
     frames_pd = common_video.swap_time_and_batch_axes(gen_frames)
+    if actions is not None:
+      actions = common_video.swap_time_and_batch_axes(actions)

     if self.is_per_pixel_softmax:
       frames_pd_shape = common_layers.shape_list(frames_pd)
@@ -526,7 +528,13 @@ def concat_on_y_axis(x):

     frames_gd = concat_on_y_axis(frames_gd)
     frames_pd = concat_on_y_axis(frames_pd)
-    side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
+    if actions is not None:
+      actions = tf.clip_by_value(actions, 0, 1)
+      summary("action_vid", tf.cast(actions * 255, tf.uint8))
+      actions = concat_on_y_axis(actions)
+      side_by_side_video = tf.concat([frames_gd, frames_pd, actions], axis=2)
+    else:
+      side_by_side_video = tf.concat([frames_gd, frames_pd], axis=2)
     tf.summary.image("full_video", side_by_side_video)

   def get_input_if_exists(self, features, key, batch_size, num_frames):
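
A standalone sketch of the resulting summary layout (illustrative shapes; tf.summary.image stands in for the model's summary helper used for "action_vid"): ground truth, predictions and action frames are each tiled along height, then placed side by side along width.

import tensorflow as tf

batch, time, h, w, c = 8, 10, 64, 64, 3       # illustrative shapes
frames_gd = tf.zeros((batch, time, h, w, c))  # ground-truth frames
frames_pd = tf.zeros((batch, time, h, w, c))  # predicted frames
actions = tf.zeros((batch, time, h, w, c))    # action images in [0, 1]

def concat_on_y_axis(x):
  x = tf.unstack(x, axis=1)  # one tensor per time step
  x = tf.concat(x, axis=1)   # stack the steps along height
  return x

frames_gd = concat_on_y_axis(frames_gd)       # (8, 640, 64, 3)
frames_pd = concat_on_y_axis(frames_pd)
actions = concat_on_y_axis(actions)
side_by_side_video = tf.concat([frames_gd, frames_pd, actions], axis=2)  # (8, 640, 192, 3)
tf.summary.image("full_video", side_by_side_video)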

tensor2tensor/models/video/sv2p_params.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ def next_frame_sv2p():
   hparams.add_hparam("upsample_method", "conv2d_transpose")
   hparams.add_hparam("reward_model", "basic")
   hparams.add_hparam("visualize_logits_histogram", True)
+  hparams.add_hparam("action_normalize", False)
   return hparams