@@ -799,6 +799,7 @@ def main():
799799 args ,
800800 amp_autocast = amp_autocast ,
801801 tensorboard_writer = tensorboard_writer ,
802+ epoch = epoch ,
802803 )
803804
804805 if model_ema is not None and not args .model_ema_force_cpu :
@@ -812,6 +813,8 @@ def main():
812813 args ,
813814 amp_autocast = amp_autocast ,
814815 log_suffix = ' (EMA)' ,
816+ tensorboard_writer = tensorboard_writer ,
817+ epoch = epoch ,
815818 )
816819 eval_metrics = ema_eval_metrics
817820
@@ -989,6 +992,7 @@ def validate(
989992 amp_autocast = suppress ,
990993 log_suffix = '' ,
991994 tensorboard_writer = None ,
995+ epoch = None ,
992996
993997):
994998 batch_time_m = utils .AverageMeter ()
@@ -1040,9 +1044,10 @@ def validate(
10401044 batch_time_m .update (time .time () - end )
10411045 end = time .time ()
10421046 if should_log_to_tensorboard (args ):
1043- tensorboard_writer .add_scalar ('val/loss' , losses_m .val , batch_idx )
1044- tensorboard_writer .add_scalar ('val/acc1' , top1_m .val , batch_idx )
1045- tensorboard_writer .add_scalar ('val/acc5' , top5_m .val , batch_idx )
1047+ #by the updates
1048+ tensorboard_writer .add_scalar ('val/loss' , losses_m .val , epoch * last_idx + batch_idx )
1049+ tensorboard_writer .add_scalar ('val/acc1' , top1_m .val , epoch * last_idx + batch_idx )
1050+ tensorboard_writer .add_scalar ('val/acc5' , top5_m .val , epoch * last_idx + batch_idx )
10461051 if utils .is_primary (args ) and (last_batch or batch_idx % args .log_interval == 0 ):
10471052 log_name = 'Test' + log_suffix
10481053 _logger .info (
0 commit comments