@@ -222,6 +222,8 @@ def feed_data(self, data_loader, training=True):
222222
223223 batch_size = data_loader .batch_size
224224
225+ if self .device .type == 'cuda' :
226+ torch .cuda .synchronize ()
225227 end = time .time ()
226228 for i , (src , tgt ) in enumerate (data_loader ):
227229 self .save_counter += 1
@@ -241,12 +243,14 @@ def feed_data(self, data_loader, training=True):
241243 losses_per_sentence .update (loss_per_sentence , batch_size )
242244
243245 # measure elapsed time
246+ if self .device .type == 'cuda' :
247+ torch .cuda .synchronize ()
244248 elapsed = time .time () - end
245249 batch_time .update (elapsed )
246- src_tok_time .update (num_toks ['src' ] / elapsed )
247- tgt_tok_time .update (num_toks ['tgt' ] / elapsed )
250+ src_tok_time .update (num_toks ['src' ] / elapsed , elapsed )
251+ tgt_tok_time .update (num_toks ['tgt' ] / elapsed , elapsed )
248252 tot_num_toks = num_toks ['tgt' ] + num_toks ['src' ]
249- tot_tok_time .update (tot_num_toks / elapsed )
253+ tot_tok_time .update (tot_num_toks / elapsed , elapsed )
250254 self .loss = losses_per_token .avg
251255
252256 if training and i in eval_iters :
@@ -298,6 +302,8 @@ def feed_data(self, data_loader, training=True):
298302 if rank == 0 :
299303 self .save (identifier = identifier )
300304
305+ if self .device .type == 'cuda' :
306+ torch .cuda .synchronize ()
301307 end = time .time ()
302308
303309 tot_tok_time .reduce ('sum' )
0 commit comments