@@ -30,6 +30,27 @@ To evaluate all checkpoints in a given directory:
3030 --hparams_set=transformer_big_single_gpu
3131 --source=wmt13_deen.en
3232 --reference=wmt13_deen.de`
33+
34+ In addition to the above-mentioned compulsory parameters,
35+ there are optional parameters:
36+
37+ * bleu_variant: cased (case-sensitive), uncased, both (default).
38+ * translations_dir: Where to store the translated files? Default="translations".
39+ * event_subdir: Where in the model_dir to store the event file? Default="",
40+ which means TensorBoard will show it as the same run as the training, but it will warn
41+ about "more than one metagraph event per run". event_subdir can be used e.g. if running
42+ this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`.
43+ * tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix
44+ can be used e.g. for different beam sizes if these should be plotted in different graphs.
45+ * min_steps: Evaluate only checkpoints with more than this many steps.
46+ Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps
47+ of the last successfully evaluated checkpoint.
48+ * report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True.
49+ This is useful, so TensorBoard reports correct relative time for the remaining checkpoints.
50+ This flag is set to False if min_steps is > 0.
51+ * wait_secs: Wait up to N seconds for a new checkpoint. Default=0.
52+ This is useful for continuous evaluation of a running training,
53+ in which case this should be equal to save_checkpoints_secs plus some reserve.
3354"""
3455from __future__ import absolute_import
3556from __future__ import division
@@ -53,7 +74,11 @@ flags.DEFINE_string("translation", None, "Path to the MT system translation file
5374flags .DEFINE_string ("source" , None , "Path to the source-language file to be translated" )
5475flags .DEFINE_string ("reference" , None , "Path to the reference translation file" )
5576flags .DEFINE_string ("translations_dir" , "translations" , "Where to store the translated files" )
56- flags .DEFINE_bool ("report_zero" , True , "Store BLEU=0 and guess its time via flags.txt" )
77+ flags .DEFINE_string ("event_subdir" , "" , "Where in model_dir to store the event file" )
78+ flags .DEFINE_string ("tag_suffix" , "" , "What to add to BLEU_cased and BLEU_uncased tags. Default=''." )
79+ flags .DEFINE_integer ("min_steps" , - 1 , "Don't evaluate checkpoints with less steps." )
80+ flags .DEFINE_integer ("wait_secs" , 0 , "Wait upto N seconds for a new checkpoint, cf. save_checkpoints_secs." )
81+ flags .DEFINE_bool ("report_zero" , None , "Store BLEU=0 and guess its time based on flags.txt" )
5782
5883# options derived from t2t-decode
5984flags .DEFINE_integer ("decode_shards" , 1 , "Number of decoding replicas." )
@@ -70,6 +95,11 @@ flags.DEFINE_string("schedule", "train_and_evaluate",
7095Model = namedtuple ('Model' , 'filename time steps' )
7196
7297
def read_checkpoints_list(model_dir, min_steps):
  """Lists the checkpoints in model_dir that have more than min_steps steps.

  Checkpoints are discovered through their `model.ckpt-*.index` files.

  Args:
    model_dir: directory that is searched for checkpoint index files.
    min_steps: only checkpoints whose step count is strictly greater
      than this value are returned.

  Returns:
    A list of Model tuples (filename, time, steps) sorted by ascending steps,
    where filename is the index path without the ".index" suffix, time is
    os.path.getctime of the index file and steps is the integer after the
    last '-' in the filename.
  """
  index_pattern = os.path.join(model_dir, 'model.ckpt-*.index')
  checkpoints = []
  for index_path in tf.gfile.Glob(index_pattern):
    prefix = index_path[:-6]  # drop the trailing ".index"
    created = os.path.getctime(index_path)
    steps = int(prefix.rsplit('-')[-1])
    checkpoints.append(Model(prefix, created, steps))
  # Filter only after all entries are built, then order by step count.
  selected = [ckpt for ckpt in checkpoints if ckpt.steps > min_steps]
  selected.sort(key=lambda ckpt: ckpt.steps)
  return selected
102+
73103def main (_ ):
74104 tf .logging .set_verbosity (tf .logging .INFO )
75105 if FLAGS .translation :
@@ -107,22 +137,43 @@ def main(_):
107137
108138 os .makedirs (FLAGS .translations_dir , exist_ok = True )
109139 translated_base_file = os .path .join (FLAGS .translations_dir , FLAGS .problems )
110- models = [Model (x [:- 6 ], os .path .getctime (x ), int (x [:- 6 ].rsplit ('-' )[- 1 ]))
111- for x in tf .gfile .Glob (os .path .join (model_dir , 'model.ckpt-*.index' ))]
112- models = sorted (models , key = lambda x : x .time )
140+ event_dir = os .path .join (FLAGS .model_dir , FLAGS .event_subdir )
141+ last_step_file = os .path .join (event_dir , 'last_evaluated_step.txt' )
142+ if FLAGS .min_steps == - 1 :
143+ try :
144+ with open (last_step_file ) as ls_file :
145+ FLAGS .min_steps = int (ls_file .read ())
146+ except FileNotFoundError :
147+ FLAGS .min_steps = 0
148+ if FLAGS .report_zero is None :
149+ FLAGS .report_zero = FLAGS .min_steps == 0
150+
151+ models = read_checkpoints_list (model_dir , FLAGS .min_steps )
113152 tf .logging .info ("Found %d models with steps: %s" % (len (models ), ", " .join (str (x .steps ) for x in models )))
114153
115- writer = tf .summary .FileWriter (FLAGS . model_dir )
154+ writer = tf .summary .FileWriter (event_dir )
116155 if FLAGS .report_zero :
117156 start_time = os .path .getctime (os .path .join (model_dir , 'flags.txt' ))
118157 values = []
119158 if FLAGS .bleu_variant in ('uncased' , 'both' ):
120- values .append (tf .Summary .Value (tag = 'BLEU_uncased' , simple_value = 0 ))
159+ values .append (tf .Summary .Value (tag = 'BLEU_uncased' + FLAGS . tag_suffix , simple_value = 0 ))
121160 if FLAGS .bleu_variant in ('cased' , 'both' ):
122- values .append (tf .Summary .Value (tag = 'BLEU_cased' , simple_value = 0 ))
161+ values .append (tf .Summary .Value (tag = 'BLEU_cased' + FLAGS . tag_suffix , simple_value = 0 ))
123162 writer .add_event (tf .summary .Event (summary = tf .Summary (value = values ), wall_time = start_time , step = 0 ))
124163
125- for model in models :
164+ exit_time = time .time () + FLAGS .wait_secs
165+ min_steps = FLAGS .min_steps
166+ while True :
167+ if not models and FLAGS .wait_secs :
168+ tf .logging .info ('All checkpoints evaluated. Waiting till %s if a new checkpoint appears' % time .asctime (time .localtime (exit_time )))
169+ while not models and time .time () < exit_time :
170+ time .sleep (10 )
171+ models = read_checkpoints_list (model_dir , min_steps )
172+ if not models :
173+ return
174+
175+ model = models .pop (0 )
176+ exit_time , min_steps = model .time + FLAGS .wait_secs , model .steps
126177 tf .logging .info ("Evaluating " + model .filename )
127178 out_file = translated_base_file + '-' + str (model .steps )
128179 tf .logging .set_verbosity (tf .logging .ERROR ) # decode_from_file logs all the translations as INFO
@@ -131,15 +182,17 @@ def main(_):
131182 values = []
132183 if FLAGS .bleu_variant in ('uncased' , 'both' ):
133184 bleu = 100 * bleu_hook .bleu_wrapper (FLAGS .reference , out_file , case_sensitive = False )
134- values .append (tf .Summary .Value (tag = 'BLEU_uncased' , simple_value = bleu ))
185+ values .append (tf .Summary .Value (tag = 'BLEU_uncased' + FLAGS . tag_suffix , simple_value = bleu ))
135186 tf .logging .info ("%s: BLEU_uncased = %6.2f" % (model .filename , bleu ))
136187 if FLAGS .bleu_variant in ('cased' , 'both' ):
137188 bleu = 100 * bleu_hook .bleu_wrapper (FLAGS .reference , out_file , case_sensitive = True )
138- values .append (tf .Summary .Value (tag = 'BLEU_cased' , simple_value = bleu ))
189+ values .append (tf .Summary .Value (tag = 'BLEU_cased' + FLAGS . tag_suffix , simple_value = bleu ))
139190 tf .logging .info ("%s: BLEU_cased = %6.2f" % (model .filename , bleu ))
140191 writer .add_event (tf .summary .Event (summary = tf .Summary (value = values ), wall_time = model .time , step = model .steps ))
192+ writer .flush ()
193+ with open (last_step_file , 'w' ) as ls_file :
194+ ls_file .write (str (model .steps ) + '\n ' )
141195
142- writer .flush ()
143196
144197if __name__ == "__main__" :
145198 tf .app .run ()
0 commit comments