@@ -566,6 +566,11 @@ def main():
566566 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
567567 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.' )
568568
569+ model_patch_size = None
570+ if args .naflex_loader :
571+ # NaFlexVit models have embeds.patch_size. Needs to be extracted here before mutating the model.
572+ model_patch_size = getattr (getattr (model , "embeds" , None ), "patch_size" , None )
573+
569574 if args .torchscript :
570575 assert not args .torchcompile
571576 assert not use_amp == 'apex' , 'Cannot use APEX AMP with torchscripted model'
@@ -762,7 +767,6 @@ def main():
762767 )
763768
764769 naflex_mode = False
765- model_patch_size = None
766770 if args .naflex_loader :
767771 if utils .is_primary (args ):
768772 _logger .info ('Using NaFlex loader' )
@@ -775,11 +779,8 @@ def main():
775779 mixup_args .pop ('cutmix_minmax' ) # not supported
776780 naflex_mixup_fn = NaFlexMixup (** mixup_args )
777781
778- # Extract model's patch size for NaFlex mode
779- if hasattr (model , 'embeds' ) and hasattr (model .embeds , 'patch_size' ):
780- # NaFlexVit models have embeds.patch_size
781- model_patch_size = model .embeds .patch_size
782- else :
782+ # Check if we have model's patch size for NaFlex mode
783+ if model_patch_size is None :
783784 # Fallback to default
784785 model_patch_size = (16 , 16 )
785786 if utils .is_primary (args ):
@@ -1197,6 +1198,7 @@ def _backward(_loss):
11971198 dist_scale = args .world_size * batch_size / global_batch_size
11981199 else :
11991200 dist_scale = None
1201+ global_batch_size = batch_size
12001202
12011203 if has_no_sync and not need_update :
12021204 with model .no_sync ():
@@ -1212,7 +1214,10 @@ def _backward(_loss):
12121214 scaled_loss *= dist_scale
12131215 _backward (scaled_loss )
12141216 else :
1215- batch_size = input .shape [0 ]
1217+ global_batch_size = batch_size = input .shape [0 ]
1218+ if args .distributed :
1219+ global_batch_size *= args .world_size
1220+
12161221 if has_no_sync and not need_update :
12171222 with model .no_sync ():
12181223 loss = _forward ()
@@ -1222,7 +1227,7 @@ def _backward(_loss):
12221227 _backward (loss )
12231228
12241229 losses_m .update (loss .item () * accum_steps , batch_size )
1225- update_sample_count += batch_size
1230+ update_sample_count += global_batch_size
12261231
12271232 if not need_update :
12281233 data_start_time = time .time ()
@@ -1240,7 +1245,7 @@ def _backward(_loss):
12401245 torch .npu .synchronize ()
12411246 time_now = time .time ()
12421247
1243- update_time_m .update (( time .time () - update_start_time ) / update_sample_count , update_sample_count )
1248+ update_time_m .update (time .time () - update_start_time )
12441249 update_start_time = time_now
12451250
12461251 if update_idx % args .log_interval == 0 :
@@ -1252,15 +1257,14 @@ def _backward(_loss):
12521257 # synchronize current step and avg loss, each process keeps its own running avg
12531258 loss_avg = utils .reduce_tensor (loss .new ([loss_avg ]), args .world_size ).item ()
12541259 loss_now = utils .reduce_tensor (loss .new ([loss_now ]), args .world_size ).item ()
1255- update_sample_count *= args .world_size
12561260
12571261 if utils .is_primary (args ):
12581262 _logger .info (
12591263 f'Train: { epoch } [{ update_idx :>4d} /{ updates_per_epoch } '
12601264 f'({ 100. * (update_idx + 1 ) / updates_per_epoch :>3.0f} %)] '
12611265 f'Loss: { loss_now :#.3g} ({ loss_avg :#.3g} ) '
1262- f'Time: { update_time_m .val :.3f} s, { 1 / update_time_m .val :>7.2f} /s '
1263- f'({ update_time_m .avg :.3f} s, { 1 / update_time_m .avg :>7.2f} /s) '
1266+ f'Time: { update_time_m .val :.3f} s, { update_sample_count / update_time_m .val :>7.2f} /s '
1267+ f'({ update_time_m .avg :.3f} s, { update_sample_count / update_time_m .avg :>7.2f} /s) '
12641268 f'LR: { lr :.3e} '
12651269 f'Data: { data_time_m .val :.3f} ({ data_time_m .avg :.3f} )'
12661270 )
0 commit comments