@@ -18,9 +18,9 @@
 import tensorflow as tf
 
 
-def define_collect(policy_factory, batch_env, hparams, eval):
+def define_collect(policy_factory, batch_env, hparams, eval_phase):
   """Collect trajectories."""
-  eval = tf.convert_to_tensor(eval)
+  eval_phase = tf.convert_to_tensor(eval_phase)
   memory_shape = [hparams.epoch_length] + [batch_env.observ.shape.as_list()[0]]
   memories_shapes_and_types = [
       # observation
@@ -39,7 +39,7 @@ def define_collect(policy_factory, batch_env, hparams, eval):
 
   should_reset_var = tf.Variable(True, trainable=False)
   reset_op = tf.cond(
-      tf.logical_or(should_reset_var, eval),
+      tf.logical_or(should_reset_var, eval_phase),
       lambda: tf.group(batch_env.reset(tf.range(len(batch_env))),
                        tf.assign(cumulative_rewards, tf.zeros(len(batch_env)))),
       lambda: tf.no_op())
@@ -57,7 +57,7 @@ def step(index, scores_sum, scores_num):
     obs_copy = batch_env.observ + 0
     actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
     policy = actor_critic.policy
-    action = tf.cond(eval,
+    action = tf.cond(eval_phase,
                      policy.mode,
                      policy.sample)
     postprocessed_action = actor_critic.action_postprocessing(action)
@@ -88,7 +88,7 @@ def step(index, scores_sum, scores_num):
             scores_num + scores_num_delta]
 
   def stop_condition(i, _, resets):
-    return tf.cond(eval,
+    return tf.cond(eval_phase,
                    lambda: resets < hparams.num_eval_agents,
                    lambda: i < hparams.epoch_length)
 
@@ -105,9 +105,10 @@ def stop_condition(i, _, resets):
   printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
   with tf.control_dependencies([index, printing]):
     memory = [tf.identity(mem) for mem in memory]
-  mean_score_summary = tf.cond(tf.greater(scores_num, 0),
-      lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
-      str)
+  mean_score_summary = tf.cond(
+      tf.greater(scores_num, 0),
+      lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
+      str)
   summaries = tf.summary.merge(
       [mean_score_summary,
        tf.summary.scalar("episodes_finished_this_iter", scores_num)])