@@ -214,6 +214,7 @@ def run(exe,
214214 last_step = args .last_step_of_checkpoint
215215 train_iter = 0
216216 epoch = 0
217+ train_time_raw = 0
217218 if progress is None :
218219 progress = dict ()
219220 else :
@@ -229,9 +230,12 @@ def run(exe,
229230 f"Only { max_steps - last_step } steps will be performed in this run due to the limit of --max-steps."
230231 )
231232 else :
232- max_steps = args .steps_this_run + last_step
233+ steps_this_run = args .steps_this_run
234+ if args .benchmark :
235+ steps_this_run = min (steps_this_run , args .benchmark_warmup_steps + args .benchmark_steps )
236+ max_steps = steps_this_run + last_step
233237 logging .warning (
234- f"{ args . steps_this_run } steps will be performed in this run." )
238+ f"{ steps_this_run } steps will be performed in this run." )
235239
236240 total_samples = 0
237241 raw_train_start = time .time ()
@@ -272,6 +276,7 @@ def run(exe,
272276
273277 if train_iter % (save_steps * gradient_merge_steps
274278 ) == 0 or global_step >= max_steps :
279+ train_time_raw = time .time () - raw_train_start
275280 if trainer_id == 0 :
276281 model_path = os .path .join (
277282 args .output_dir , args .bert_model , "phase1"
@@ -287,9 +292,7 @@ def run(exe,
287292 if len (most_recent_ckpts_paths ) > 3 :
288293 ckpt_to_be_removed = most_recent_ckpts_paths .pop (0 )
289294 shutil .rmtree (ckpt_to_be_removed )
290- if (global_step >= max_steps ) or (
291- args .benchmark and global_step >=
292- args .benchmark_steps + args .benchmark_warmup_steps ):
293- train_time_raw = time .time () - raw_train_start
294- return global_step , loss_return [0 ].item (), train_time_raw
295+ if global_step >= max_steps :
296+ actual_steps_this_run = global_step - last_step
297+ return global_step , actual_steps_this_run , loss_return [0 ].item (), train_time_raw
295298 epoch += 1
0 commit comments