Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 1bb6fb4

Browse files
author
Błażej O
committed
Renaming config to hparams.
1 parent 5eb0f96 commit 1bb6fb4

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

tensor2tensor/rl/collect.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import tensorflow as tf
1919

2020

21-
def define_collect(policy_factory, batch_env, config):
21+
def define_collect(policy_factory, batch_env, hparams):
2222
"""Collect trajectories."""
23-
memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
23+
memory_shape = [hparams.epoch_length] + [batch_env.observ.shape.as_list()[0]]
2424
memories_shapes_and_types = [
2525
# observation
2626
(memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
@@ -34,11 +34,11 @@ def define_collect(policy_factory, batch_env, config):
3434
memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
3535
for (shape, dtype) in memories_shapes_and_types]
3636
cumulative_rewards = tf.Variable(
37-
tf.zeros(config.num_agents, tf.float32), trainable=False)
37+
tf.zeros(hparams.num_agents, tf.float32), trainable=False)
3838

3939
should_reset_var = tf.Variable(True, trainable=False)
4040
reset_op = tf.cond(should_reset_var,
41-
lambda: batch_env.reset(tf.range(config.num_agents)),
41+
lambda: batch_env.reset(tf.range(hparams.num_agents)),
4242
lambda: 0.0)
4343
with tf.control_dependencies([reset_op]):
4444
reset_once_op = tf.assign(should_reset_var, False)
@@ -59,7 +59,7 @@ def step(index, scores_sum, scores_num):
5959
pdf = policy.prob(action)[0]
6060
with tf.control_dependencies(simulate_output):
6161
reward, done = simulate_output
62-
done = tf.reshape(done, (config.num_agents,))
62+
done = tf.reshape(done, (hparams.num_agents,))
6363
to_save = [obs_copy, reward, done, action[0, ...], pdf,
6464
actor_critic.value[0]]
6565
save_ops = [tf.scatter_update(memory_slot, index, value)
@@ -83,7 +83,7 @@ def step(index, scores_sum, scores_num):
8383

8484
init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
8585
index, scores_sum, scores_num = tf.while_loop(
86-
lambda c, _1, _2: c < config.epoch_length,
86+
lambda c, _1, _2: c < hparams.epoch_length,
8787
step,
8888
init,
8989
parallel_iterations=1,

0 commit comments

Comments (0)