
Commit e97ce1b

Błażej O committed
Improving readability.
1 parent 3c2af07 commit e97ce1b

File tree

2 files changed: +11 -10 lines

2 files changed

+11
-10
lines changed

tensor2tensor/rl/collect.py

Lines changed: 9 additions & 8 deletions
@@ -18,9 +18,9 @@
 import tensorflow as tf
 
 
-def define_collect(policy_factory, batch_env, hparams, eval):
+def define_collect(policy_factory, batch_env, hparams, eval_phase):
   """Collect trajectories."""
-  eval = tf.convert_to_tensor(eval)
+  eval_phase = tf.convert_to_tensor(eval_phase)
   memory_shape = [hparams.epoch_length] + [batch_env.observ.shape.as_list()[0]]
   memories_shapes_and_types = [
       # observation
@@ -39,7 +39,7 @@ def define_collect(policy_factory, batch_env, hparams, eval):
 
   should_reset_var = tf.Variable(True, trainable=False)
   reset_op = tf.cond(
-      tf.logical_or(should_reset_var, eval),
+      tf.logical_or(should_reset_var, eval_phase),
       lambda: tf.group(batch_env.reset(tf.range(len(batch_env))),
                        tf.assign(cumulative_rewards, tf.zeros(len(batch_env)))),
       lambda: tf.no_op())
@@ -57,7 +57,7 @@ def step(index, scores_sum, scores_num):
     obs_copy = batch_env.observ + 0
     actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
     policy = actor_critic.policy
-    action = tf.cond(eval,
+    action = tf.cond(eval_phase,
                      policy.mode,
                      policy.sample)
     postprocessed_action = actor_critic.action_postprocessing(action)
@@ -88,7 +88,7 @@ def step(index, scores_sum, scores_num):
             scores_num + scores_num_delta]
 
   def stop_condition(i, _, resets):
-    return tf.cond(eval,
+    return tf.cond(eval_phase,
                    lambda: resets < hparams.num_eval_agents,
                    lambda: i < hparams.epoch_length)
 
@@ -105,9 +105,10 @@ def stop_condition(i, _, resets):
   printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
   with tf.control_dependencies([index, printing]):
     memory = [tf.identity(mem) for mem in memory]
-    mean_score_summary = tf.cond(tf.greater(scores_num, 0),
-        lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
-        str)
+    mean_score_summary = tf.cond(
+        tf.greater(scores_num, 0),
+        lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
+        str)
   summaries = tf.summary.merge(
       [mean_score_summary,
        tf.summary.scalar("episodes_finished_this_iter", scores_num)])

tensor2tensor/rl/rl_trainer_lib.py

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@ def define_train(hparams, environment_name, event_dir):
 
   with tf.variable_scope("train"):
     memory, collect_summary = collect.define_collect(
-        policy_factory, batch_env, hparams, eval=False)
+        policy_factory, batch_env, hparams, eval_phase=False)
     ppo_summary = ppo.define_ppo_epoch(memory, policy_factory, hparams)
     summary = tf.summary.merge([collect_summary, ppo_summary])
 
@@ -59,7 +59,7 @@ def define_train(hparams, environment_name, event_dir):
     _, eval_summary = collect.define_collect(
         policy_factory,
         utils.define_batch_env(eval_env_lambda, hparams.num_eval_agents, xvfb=True),
-        hparams, eval=True)
+        hparams, eval_phase=True)
   return summary, eval_summary
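
Since both call sites pass the flag by keyword, the rename is breaking for keyword callers, which is why this commit updates rl_trainer_lib.py together with collect.py. A self-contained sketch of that failure mode, with a hypothetical stub standing in for the real function in tensor2tensor/rl/collect.py:

def define_collect(policy_factory, batch_env, hparams, eval_phase):
  """Stub with the renamed signature from this commit."""
  return eval_phase

# Positional calls keep working, but the old keyword now fails:
#   define_collect(None, None, None, eval=False)
#   -> TypeError: define_collect() got an unexpected keyword argument 'eval'
assert define_collect(None, None, None, eval_phase=False) is False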