This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8bdaeb2

Lukasz Kaiser authored and Copybara-Service committed

RL bugfixes.

PiperOrigin-RevId: 212026070

1 parent 2506671 · commit 8bdaeb2

File tree: 7 files changed, +87 −76 lines

tensor2tensor/layers/common_layers.py

Lines changed: 5 additions & 5 deletions

@@ -253,16 +253,16 @@ def expand_squeeze_to_nd(x, n, squeeze_dim=2, expand_dim=-1):
 
 
 def standardize_images(x):
-  """Image standardization on batches."""
+  """Image standardization on batches and videos."""
   with tf.name_scope("standardize_images", [x]):
-    x = tf.to_float(x)
+    x_shape = shape_list(x)
+    x = tf.to_float(tf.reshape(x, [-1] + x_shape[-3:]))
     x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
     x_variance = tf.reduce_mean(
         tf.square(x - x_mean), axis=[1, 2, 3], keepdims=True)
-    x_shape = shape_list(x)
-    num_pixels = tf.to_float(x_shape[1] * x_shape[2] * x_shape[3])
+    num_pixels = tf.to_float(x_shape[-1] * x_shape[-2] * x_shape[-3])
    x = (x - x_mean) / tf.maximum(tf.sqrt(x_variance), tf.rsqrt(num_pixels))
-    return x
+    return tf.reshape(x, x_shape)
 
 
 def flatten4d3d(x):
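
The reshape trick is what lets the same function now handle videos: every leading axis is folded into the batch, frames are standardized individually, and the original shape is restored. A minimal usage sketch (shapes are illustrative, not from the commit):

import tensorflow as tf
from tensor2tensor.layers import common_layers

# 4-D image batches worked before: [batch, height, width, channels].
images = tf.zeros([8, 64, 64, 3], dtype=tf.uint8)
images_std = common_layers.standardize_images(images)

# 5-D videos now work too: [batch, time, height, width, channels] is
# reshaped to [-1, height, width, channels], standardized per frame,
# then reshaped back, so the output shape matches the input shape.
videos = tf.zeros([8, 4, 64, 64, 3], dtype=tf.uint8)
videos_std = common_layers.standardize_images(videos)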

tensor2tensor/layers/modalities.py

Lines changed: 6 additions & 12 deletions

@@ -535,17 +535,8 @@ def bottom(self, x):
     inputs = x
     with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
       common_layers.summarize_video(inputs, "inputs")
-      inputs_shape = common_layers.shape_list(inputs)
-      # Standardize frames.
-      inputs = tf.reshape(inputs, [-1] + inputs_shape[2:])
       inputs = common_layers.standardize_images(inputs)
-      inputs = tf.reshape(inputs, inputs_shape)
-      # Concatenate the time dimension on channels for image models to work.
-      transposed = tf.transpose(inputs, [0, 2, 3, 1, 4])
-      return tf.reshape(transposed, [
-          inputs_shape[0], inputs_shape[2], inputs_shape[3],
-          inputs_shape[1] * inputs_shape[4]
-      ])
+      return common_layers.time_to_channels(inputs)
 
   def targets_bottom(self, x, summary_prefix="targets_bottom"):  # pylint: disable=arguments-differ
     inputs = x
@@ -573,10 +564,13 @@ def top(self, body_output, targets):
     num_frames = common_layers.shape_list(targets)[1]
     body_output_shape = common_layers.shape_list(body_output)
     # We assume the body output is of this shape and layout.
+    # Note: if you tf.concat([frames], axis=-1) at the end of your model,
+    # then you need to reshape to [..., num_frames, depth] like below, not
+    # into [..., depth, num_frames] due to memory layout of concat/reshape.
     reshape_shape = body_output_shape[:-1] + [
-        num_channels, self.top_dimensionality, num_frames]
+        num_channels, num_frames, self.top_dimensionality]
     res = tf.reshape(body_output, reshape_shape)
-    res = tf.transpose(res, [0, 5, 1, 2, 3, 4])
+    res = tf.transpose(res, [0, 4, 1, 2, 3, 5])
     res_shape = common_layers.shape_list(res)
     res_argmax = tf.argmax(tf.reshape(res, [-1, res_shape[-1]]), axis=-1)
     res_argmax = tf.reshape(res_argmax, res_shape[:-1])
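
The helper common_layers.time_to_channels that replaces the inline code is not shown in this commit; judging from the deleted lines, it presumably folds the time axis into channels roughly like this (a sketch, not the library's definition):

import tensorflow as tf
from tensor2tensor.layers import common_layers

def time_to_channels_sketch(video):
  """Reconstruction of the deleted inline code: [b, t, h, w, c] -> [b, h, w, t*c]."""
  shape = common_layers.shape_list(video)
  transposed = tf.transpose(video, [0, 2, 3, 1, 4])  # move time next to channels
  return tf.reshape(
      transposed, [shape[0], shape[2], shape[3], shape[1] * shape[4]])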

tensor2tensor/models/research/autoencoders.py

Lines changed: 2 additions & 1 deletion

@@ -594,6 +594,7 @@ def encoder(self, x):
           activation=common_layers.belu,
           name="strided")
       y = x
+      y = tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
       for r in range(hparams.num_residual_layers):
         residual_filters = filters
         if r < hparams.num_residual_layers - 1:
@@ -606,7 +607,7 @@ def encoder(self, x):
             padding="SAME",
             activation=common_layers.belu,
             name="residual_%d" % r)
-        x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
+        x += y
        x = common_layers.layer_norm(x, name="ln")
      return x, layers
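
This change moves residual dropout from each branch output to the block input, so the skip additions themselves are no longer noisy. Schematically (hypothetical helpers, not the encoder's exact dataflow):

import tensorflow as tf

def residual_step_old(x, conv, keep_prob):
  # Old placement: every residual branch output was dropped before adding.
  return x + tf.nn.dropout(conv(x), keep_prob)

def residual_step_new(x, conv, keep_prob):
  # New placement: the input is dropped once, and the branch output is
  # added back to x undropped.
  y = tf.nn.dropout(x, keep_prob)
  return x + conv(y)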

tensor2tensor/models/video/basic_deterministic.py

Lines changed: 3 additions & 2 deletions

@@ -85,9 +85,9 @@ def body_single(self, features):
     # Run a stack of convolutions.
     for i in range(hparams.num_hidden_layers):
       with tf.variable_scope("layer%d" % i):
-        y = tf.layers.conv2d(x, filters, kernel1, activation=common_layers.belu,
+        y = tf.nn.dropout(x, 1.0 - hparams.dropout)
+        y = tf.layers.conv2d(y, filters, kernel1, activation=common_layers.belu,
                              strides=(1, 1), padding="SAME")
-        y = tf.nn.dropout(y, 1.0 - hparams.dropout)
         if i == 0:
           x = y
         else:
@@ -172,6 +172,7 @@ def body(self, features):
       sampled_frame = tf.reshape(
           res_frames[-1], shape[:-1] + [hparams.problem.num_channels, 256])
       sampled_frame = tf.to_float(tf.argmax(sampled_frame, axis=-1))
+      sampled_frame = common_layers.standardize_images(sampled_frame)
       if is_predicting:
         all_frames[i + hparams.video_num_input_frames] = sampled_frame
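
The added standardize_images call fixes a feedback mismatch: real input frames are standardized in the modality bottom, but the scheduled-sampling path was feeding back raw argmax pixel values in [0, 255]. A schematic of the fixed step (the 256-way per-pixel logits come from the hunk; shapes are illustrative):

import tensorflow as tf
from tensor2tensor.layers import common_layers

logits = tf.zeros([1, 64, 64, 3, 256])  # per-pixel, per-channel logits
sampled_frame = tf.to_float(tf.argmax(logits, axis=-1))  # values in [0, 255]
# Standardize so the fed-back frame matches the scale of the real,
# standardized frames it replaces in all_frames:
sampled_frame = common_layers.standardize_images(sampled_frame)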

tensor2tensor/models/video/basic_deterministic_params.py

Lines changed: 2 additions & 2 deletions

@@ -65,8 +65,8 @@ def next_frame_sampling():
   """Basic conv model with scheduled sampling."""
   hparams = next_frame_basic_deterministic()
   hparams.video_num_target_frames = 2
-  hparams.scheduled_sampling_warmup_steps = 30000
-  hparams.scheduled_sampling_prob = 0.1
+  hparams.scheduled_sampling_warmup_steps = 50000
+  hparams.scheduled_sampling_prob = 0.5
   return hparams
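
These two hparams control how aggressively the model is fed its own predictions: scheduled_sampling_prob is the final probability of replacing a ground-truth frame with a sampled one, and the warmup steps are how long that probability takes to ramp up. A hedged sketch of a linear ramp (the actual schedule lives elsewhere in t2t; this exact form is an assumption):

def scheduled_sampling_prob_at(step, warmup_steps=50000, final_prob=0.5):
  """Assumed linear ramp from 0 to final_prob over warmup_steps."""
  ramp = min(float(step) / warmup_steps, 1.0)
  return ramp * final_prob

# Under this assumption: 0.0 at step 0, 0.25 at step 25000, 0.5 from 50000 on.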

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 51 additions & 50 deletions

@@ -41,61 +41,62 @@ def define_train(hparams):
 
 
 def train(hparams, event_dir=None, model_dir=None,
-          restore_agent=True, epoch=0):
+          restore_agent=True, epoch=0, name_scope="rl_train"):
   """Train."""
   with tf.Graph().as_default():
-    train_summary_op, _, initialization = define_train(hparams)
-    if event_dir:
-      summary_writer = tf.summary.FileWriter(
-          event_dir, graph=tf.get_default_graph(), flush_secs=60)
-    else:
-      summary_writer = None
+    with tf.name_scope(name_scope):
+      train_summary_op, _, initialization = define_train(hparams)
+      if event_dir:
+        summary_writer = tf.summary.FileWriter(
+            event_dir, graph=tf.get_default_graph(), flush_secs=60)
+      else:
+        summary_writer = None
 
-    if model_dir:
-      model_saver = tf.train.Saver(
-          tf.global_variables(".*network_parameters.*"))
-    else:
-      model_saver = None
+      if model_dir:
+        model_saver = tf.train.Saver(
+            tf.global_variables(".*network_parameters.*"))
+      else:
+        model_saver = None
 
-    # TODO(piotrmilos): This should be refactored, possibly with
-    # handlers for each type of env
-    if hparams.environment_spec.simulated_env:
-      env_model_loader = tf.train.Saver(
-          tf.global_variables("next_frame*"))
-    else:
-      env_model_loader = None
+      # TODO(piotrmilos): This should be refactored, possibly with
+      # handlers for each type of env
+      if hparams.environment_spec.simulated_env:
+        env_model_loader = tf.train.Saver(
+            tf.global_variables("next_frame*"))
+      else:
+        env_model_loader = None
 
-    with tf.Session() as sess:
-      sess.run(tf.global_variables_initializer())
-      initialization(sess)
-      if env_model_loader:
-        trainer_lib.restore_checkpoint(
-            hparams.world_model_dir, env_model_loader, sess,
-            must_restore=True)
-      start_step = 0
-      if model_saver and restore_agent:
-        start_step = trainer_lib.restore_checkpoint(
-            model_dir, model_saver, sess)
+      with tf.Session() as sess:
+        sess.run(tf.global_variables_initializer())
+        initialization(sess)
+        if env_model_loader:
+          trainer_lib.restore_checkpoint(
+              hparams.world_model_dir, env_model_loader, sess,
+              must_restore=True)
+        start_step = 0
+        if model_saver and restore_agent:
+          start_step = trainer_lib.restore_checkpoint(
+              model_dir, model_saver, sess)
 
-      # Fail-friendly, don't train if already trained for this epoch
-      if start_step >= ((hparams.epochs_num * (epoch + 1))):
-        tf.logging.info("Skipping PPO training for epoch %d as train steps "
-                        "(%d) already reached", epoch, start_step)
-        return
+        # Fail-friendly, don't train if already trained for this epoch
+        if start_step >= ((hparams.epochs_num * (epoch + 1))):
+          tf.logging.info("Skipping PPO training for epoch %d as train steps "
+                          "(%d) already reached", epoch, start_step)
+          return
 
-      for epoch_index in range(hparams.epochs_num):
-        summary = sess.run(train_summary_op)
-        if summary_writer:
-          summary_writer.add_summary(summary, epoch_index)
-        if (hparams.eval_every_epochs and
-            epoch_index % hparams.eval_every_epochs == 0):
-          if summary_writer and summary:
+        for epoch_index in range(hparams.epochs_num):
+          summary = sess.run(train_summary_op)
+          if summary_writer:
            summary_writer.add_summary(summary, epoch_index)
-          else:
-            tf.logging.info("Eval summary not saved")
-        if (model_saver and hparams.save_models_every_epochs and
-            (epoch_index % hparams.save_models_every_epochs == 0 or
-             (epoch_index + 1) == hparams.epochs_num)):
-          ckpt_path = os.path.join(
-              model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
-          model_saver.save(sess, ckpt_path)
+          if (hparams.eval_every_epochs and
+              epoch_index % hparams.eval_every_epochs == 0):
+            if summary_writer and summary:
+              summary_writer.add_summary(summary, epoch_index)
+            else:
+              tf.logging.info("Eval summary not saved")
+          if (model_saver and hparams.save_models_every_epochs and
+              (epoch_index % hparams.save_models_every_epochs == 0 or
+               (epoch_index + 1) == hparams.epochs_num)):
+            ckpt_path = os.path.join(
+                model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
+            model_saver.save(sess, ckpt_path)
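
The new name_scope argument exists so that PPO graphs built for different environments get distinct op and summary name prefixes; trainer_model_based.py below passes "ppo_sim" and "ppo_real". A minimal illustration of the scoping:

import tensorflow as tf

with tf.Graph().as_default():
  with tf.name_scope("ppo_sim"):
    loss = tf.constant(0.0, name="loss")
  print(loss.name)  # "ppo_sim/loss:0" -- summaries inherit the same prefix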

tensor2tensor/rl/trainer_model_based.py

Lines changed: 18 additions & 4 deletions

@@ -183,7 +183,8 @@ def train_agent(problem_name, agent_model_dir,
       "output_dir": world_model_dir,
       "data_dir": epoch_data_dir,
   }):
-    rl_trainer_lib.train(ppo_hparams, event_dir, agent_model_dir, epoch=epoch)
+    rl_trainer_lib.train(ppo_hparams, event_dir, agent_model_dir, epoch=epoch,
+                         name_scope="ppo_sim")
 
 
 def train_agent_real_env(
@@ -218,7 +219,8 @@ def train_agent_real_env(
       "data_dir": epoch_data_dir,
   }):
     # epoch = 10**20 is a hackish way to avoid skipping training
-    rl_trainer_lib.train(ppo_hparams, event_dir, agent_model_dir, epoch=10**20)
+    rl_trainer_lib.train(ppo_hparams, event_dir, agent_model_dir, epoch=10**20,
+                         name_scope="ppo_real")
 
 
 def evaluate_world_model(simulated_problem_name, problem_name, hparams,
@@ -266,6 +268,7 @@ def train_world_model(problem_name, data_dir, output_dir, hparams, epoch):
       "hparams_set": hparams.generative_model_params,
       "hparams": "learning_rate_constant=%.6f" % learning_rate,
       "eval_steps": 100,
+      "local_eval_frequency": 2000,
       "train_steps": train_steps,
   }):
     t2t_trainer.main([])
@@ -519,6 +522,7 @@ def training_loop(hparams, output_dir, report_fn=None, report_metric=None):
     mean_reward_summary.value[0].simple_value = mean_reward
     eval_metrics_writer.add_summary(model_reward_accuracy_summary, epoch)
     eval_metrics_writer.add_summary(mean_reward_summary, epoch)
+    eval_metrics_writer.flush()
 
     # Report metrics
     eval_metrics = {"model_reward_accuracy": model_reward_accuracy,
@@ -599,10 +603,12 @@ def rl_modelrl_base():
 
 @registry.register_hparams
 def rl_modelrl_base_quick():
-  """Base setting with only 2 epochs and 500 PPO steps per epoch."""
+  """Base setting but quicker with only 2 epochs."""
   hparams = rl_modelrl_base()
   hparams.epochs = 2
-  hparams.ppo_epochs_num = 500
+  hparams.ppo_epochs_num = 1000
+  hparams.ppo_epoch_length = 50
+  hparams.real_ppo_epochs_num = 10
   return hparams
 
 
@@ -615,6 +621,14 @@ def rl_modelrl_base_quick_sd():
   return hparams
 
 
+@registry.register_hparams
+def rl_modelrl_base_quick_sm():
+  """Quick setting with sampling."""
+  hparams = rl_modelrl_base_quick()
+  hparams.generative_model_params = "next_frame_sampling"
+  return hparams
+
+
 @registry.register_hparams
 def rl_modelrl_base_stochastic():
   """Base setting with a stochastic next-frame model."""
