1818import tensorflow as tf
1919
2020
21- def define_collect (policy_factory , batch_env , config ):
21+ def define_collect (policy_factory , batch_env , hparams ):
2222 """Collect trajectories."""
23- memory_shape = [config .epoch_length ] + [batch_env .observ .shape .as_list ()[0 ]]
23+ memory_shape = [hparams .epoch_length ] + [batch_env .observ .shape .as_list ()[0 ]]
2424 memories_shapes_and_types = [
2525 # observation
2626 (memory_shape + [batch_env .observ .shape .as_list ()[1 ]], tf .float32 ),
@@ -34,11 +34,11 @@ def define_collect(policy_factory, batch_env, config):
3434 memory = [tf .Variable (tf .zeros (shape , dtype ), trainable = False )
3535 for (shape , dtype ) in memories_shapes_and_types ]
3636 cumulative_rewards = tf .Variable (
37- tf .zeros (config .num_agents , tf .float32 ), trainable = False )
37+ tf .zeros (hparams .num_agents , tf .float32 ), trainable = False )
3838
3939 should_reset_var = tf .Variable (True , trainable = False )
4040 reset_op = tf .cond (should_reset_var ,
41- lambda : batch_env .reset (tf .range (config .num_agents )),
41+ lambda : batch_env .reset (tf .range (hparams .num_agents )),
4242 lambda : 0.0 )
4343 with tf .control_dependencies ([reset_op ]):
4444 reset_once_op = tf .assign (should_reset_var , False )
@@ -59,7 +59,7 @@ def step(index, scores_sum, scores_num):
5959 pdf = policy .prob (action )[0 ]
6060 with tf .control_dependencies (simulate_output ):
6161 reward , done = simulate_output
62- done = tf .reshape (done , (config .num_agents ,))
62+ done = tf .reshape (done , (hparams .num_agents ,))
6363 to_save = [obs_copy , reward , done , action [0 , ...], pdf ,
6464 actor_critic .value [0 ]]
6565 save_ops = [tf .scatter_update (memory_slot , index , value )
@@ -83,7 +83,7 @@ def step(index, scores_sum, scores_num):
8383
8484 init = [tf .constant (0 ), tf .constant (0.0 ), tf .constant (0 )]
8585 index , scores_sum , scores_num = tf .while_loop (
86- lambda c , _1 , _2 : c < config .epoch_length ,
86+ lambda c , _1 , _2 : c < hparams .epoch_length ,
8787 step ,
8888 init ,
8989 parallel_iterations = 1 ,
0 commit comments