@@ -41,61 +41,62 @@ def define_train(hparams):
4141
4242
def train(hparams, event_dir=None, model_dir=None,
          restore_agent=True, epoch=0, name_scope="rl_train"):
  """Run PPO training for one epoch's worth of train steps.

  Builds the training graph (under ``name_scope``), optionally restores the
  agent and/or the simulated-environment world model from checkpoints, runs
  ``hparams.epochs_num`` training steps, and periodically writes summaries
  and saves agent checkpoints.

  Args:
    hparams: HParams with at least ``epochs_num``, ``eval_every_epochs``,
      ``save_models_every_epochs``, ``environment_spec`` and (for simulated
      envs) ``world_model_dir``.
    event_dir: directory for TensorBoard event files; no summaries are
      written when None.
    model_dir: directory for agent checkpoints; no saving/restoring of the
      agent happens when None.
    restore_agent: whether to restore agent variables from ``model_dir``.
    epoch: index of the current (outer) epoch; used to decide whether this
      epoch's training has already been completed.
    name_scope: TF name scope wrapping the training graph.
  """
  with tf.Graph().as_default():
    with tf.name_scope(name_scope):
      train_summary_op, _, initialization = define_train(hparams)
      if event_dir:
        summary_writer = tf.summary.FileWriter(
            event_dir, graph=tf.get_default_graph(), flush_secs=60)
      else:
        summary_writer = None

      # Saver over agent (policy/value network) variables only.
      if model_dir:
        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
      else:
        model_saver = None

      # TODO(piotrmilos): This should be refactored, possibly with
      # handlers for each type of env
      if hparams.environment_spec.simulated_env:
        # Loader for the learned world-model variables.
        env_model_loader = tf.train.Saver(
            tf.global_variables("next_frame*"))
      else:
        env_model_loader = None

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        initialization(sess)
        if env_model_loader:
          # The world model must exist when training in a simulated env.
          trainer_lib.restore_checkpoint(
              hparams.world_model_dir, env_model_loader, sess,
              must_restore=True)
        start_step = 0
        if model_saver and restore_agent:
          start_step = trainer_lib.restore_checkpoint(
              model_dir, model_saver, sess)

        # Fail-friendly, don't train if already trained for this epoch
        if start_step >= (hparams.epochs_num * (epoch + 1)):
          tf.logging.info("Skipping PPO training for epoch %d as train steps "
                          "(%d) already reached", epoch, start_step)
          return

        for epoch_index in range(hparams.epochs_num):
          summary = sess.run(train_summary_op)
          if summary_writer:
            summary_writer.add_summary(summary, epoch_index)
          if (hparams.eval_every_epochs and
              epoch_index % hparams.eval_every_epochs == 0):
            # The summary for this step was already written above; writing it
            # a second time at the same step would duplicate the event. Only
            # log when there was nothing to write.
            if not (summary_writer and summary):
              tf.logging.info("Eval summary not saved")
          if (model_saver and hparams.save_models_every_epochs and
              (epoch_index % hparams.save_models_every_epochs == 0 or
               (epoch_index + 1) == hparams.epochs_num)):
            ckpt_path = os.path.join(
                model_dir, "model.ckpt-{}".format(epoch_index + 1 + start_step))
            model_saver.save(sess, ckpt_path)
0 commit comments