@@ -798,6 +798,7 @@ def main():
798798 validate_loss_fn ,
799799 args ,
800800 amp_autocast = amp_autocast ,
801+ tensorboard_writer = tensorboard_writer ,
801802 )
802803
803804 if model_ema is not None and not args .model_ema_force_cpu :
@@ -922,8 +923,8 @@ def train_one_epoch(
922923 batch_time_m .update (time .time () - end )
923924 #write to tensorboard if enabled
924925 if should_log_to_tensorboard (args ):
925- writer .add_scalar ('train/loss' , losses_m .val , num_updates )
926- writer .add_scalar ('train/lr' , optimizer .param_groups [0 ]['lr' ], num_updates )
926+ tensorboard_writer .add_scalar ('train/loss' , losses_m .val , num_updates )
927+ tensorboard_writer .add_scalar ('train/lr' , optimizer .param_groups [0 ]['lr' ], num_updates )
927928 if last_batch or batch_idx % args .log_interval == 0 :
928929 lrl = [param_group ['lr' ] for param_group in optimizer .param_groups ]
929930 lr = sum (lrl ) / len (lrl )
@@ -986,7 +987,9 @@ def validate(
986987 args ,
987988 device = torch .device ('cuda' ),
988989 amp_autocast = suppress ,
989- log_suffix = ''
990+ log_suffix = '' ,
991+ tensorboard_writer = None ,
992+
990993):
991994 batch_time_m = utils .AverageMeter ()
992995 losses_m = utils .AverageMeter ()
@@ -1037,9 +1040,9 @@ def validate(
10371040 batch_time_m .update (time .time () - end )
10381041 end = time .time ()
10391042 if should_log_to_tensorboard (args ):
1040- writer .add_scalar ('val/loss' , losses_m .val , batch_idx )
1041- writer .add_scalar ('val/acc1' , top1_m .val , batch_idx )
1042- writer .add_scalar ('val/acc5' , top5_m .val , batch_idx )
1043+ tensorboard_writer .add_scalar ('val/loss' , losses_m .val , batch_idx )
1044+ tensorboard_writer .add_scalar ('val/acc1' , top1_m .val , batch_idx )
1045+ tensorboard_writer .add_scalar ('val/acc5' , top5_m .val , batch_idx )
10431046 if utils .is_primary (args ) and (last_batch or batch_idx % args .log_interval == 0 ):
10441047 log_name = 'Test' + log_suffix
10451048 _logger .info (
0 commit comments