                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
-parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
-                    help='Best metric (default: "prec1"')
+parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
+                    help='Best metric (default: "top1"')
 parser.add_argument('--tta', type=int, default=0, metavar='N',
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument("--local_rank", default=0, type=int)
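The renamed default matters because the `--eval-metric` string is used as a lookup key into the metrics dict that `validate()` returns, so the argument default and the dict keys ('prec1'/'prec5' vs. 'top1'/'top5') have to change together. A minimal, self-contained sketch of that relationship (the hard-coded numbers and flow here are illustrative, not the actual train.py logic):

```python
import argparse
from collections import OrderedDict

parser = argparse.ArgumentParser()
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                    help='Best metric (default: "top1")')
args = parser.parse_args([])  # use the defaults for this demo

# Shape of the dict returned by validate() after the rename below.
eval_metrics = OrderedDict([('loss', 1.234), ('top1', 76.5), ('top5', 93.1)])

# Checkpoint tracking simply indexes with args.eval_metric, so the key
# spelling must match the dict built in validate().
best_metric = eval_metrics[args.eval_metric]
print(args.eval_metric, best_metric)  # -> top1 76.5
```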
@@ -596,8 +596,8 @@ def train_epoch(
 def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
-    prec1_m = AverageMeter()
-    prec5_m = AverageMeter()
+    top1_m = AverageMeter()
+    top5_m = AverageMeter()
 
     model.eval()
 
@@ -621,20 +621,20 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
                 target = target[0:target.size(0):reduce_factor]
 
             loss = loss_fn(output, target)
-            prec1, prec5 = accuracy(output, target, topk=(1, 5))
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
 
             if args.distributed:
                 reduced_loss = reduce_tensor(loss.data, args.world_size)
-                prec1 = reduce_tensor(prec1, args.world_size)
-                prec5 = reduce_tensor(prec5, args.world_size)
+                acc1 = reduce_tensor(acc1, args.world_size)
+                acc5 = reduce_tensor(acc5, args.world_size)
             else:
                 reduced_loss = loss.data
 
             torch.cuda.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
-            prec1_m.update(prec1.item(), output.size(0))
-            prec5_m.update(prec5.item(), output.size(0))
+            top1_m.update(acc1.item(), output.size(0))
+            top5_m.update(acc5.item(), output.size(0))
 
             batch_time_m.update(time.time() - end)
             end = time.time()
@@ -644,13 +644,12 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
                     '{0}: [{1:>4d}/{2}] '
                     'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx,
-                        batch_time=batch_time_m, loss=losses_m,
-                        top1=prec1_m, top5=prec5_m))
+                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
+                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                        log_name, batch_idx, last_idx, batch_time=batch_time_m,
+                        loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])
+    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics
 
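For reference, the acc1/acc5 values above come from timm's accuracy() helper. A rough, self-contained sketch of what a top-k accuracy computation looks like (an illustrative re-implementation, not the library's exact code):

```python
import torch

def topk_accuracy(output, target, topk=(1, 5)):
    """Illustrative top-k accuracy; returns one percentage tensor per k."""
    maxk = max(topk)
    batch_size = target.size(0)
    # Indices of the maxk highest-scoring classes per sample: (batch, maxk).
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    # Broadcast-compare against the targets, transposed to (maxk, batch).
    correct = pred.t().eq(target.view(1, -1))
    # Keep tensors (not floats) so callers can all-reduce them, then .item().
    return [correct[:k].reshape(-1).float().sum() * 100.0 / batch_size
            for k in topk]

# Example: 8 samples, 10 classes.
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
acc1, acc5 = topk_accuracy(logits, labels, topk=(1, 5))
print(f'Acc@1: {acc1.item():.2f}  Acc@5: {acc5.item():.2f}')
```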