     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
 
 from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
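For context, a hedged sketch of how the full import block reads once this hunk is applied; the enclosing try/except and the Apex amp/DDP imports are assumptions based on the surrounding file (they are referenced later in this diff but not shown in the hunk):

try:
    from apex import amp  # assumed: amp.initialize() is called later in this diff
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    # fall back to the native torch DDP wrapper when NVIDIA Apex is not installed,
    # so the name DDP resolves either way
    from torch.nn.parallel import DistributedDataParallel as DDP
    has_apex = False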
@@ -169,8 +170,9 @@ def main():
         bn_eps=args.bn_eps,
         checkpoint_path=args.initial_checkpoint)
 
-    logging.info('Model %s created, param count: %d' %
-                 (args.model, sum([m.numel() for m in model.parameters()])))
+    if args.local_rank == 0:
+        logging.info('Model %s created, param count: %d' %
+                     (args.model, sum([m.numel() for m in model.parameters()])))
 
     data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
 
@@ -187,36 +189,47 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
-        if args.distributed and args.sync_bn and has_apex:
-            model = convert_syncbn_model(model)
         model.cuda()
 
     optimizer = create_optimizer(args, model)
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)
 
+    use_amp = False
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
-        logging.info('AMP enabled')
-    else:
-        use_amp = False
-        logging.info('AMP disabled')
+    if args.local_rank == 0:
+        logging.info('NVIDIA APEX {}. AMP {}.'.format(
+            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
 
     model_ema = None
     if args.model_ema:
+        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
         model_ema = ModelEma(
             model,
             decay=args.model_ema_decay,
             device='cpu' if args.model_ema_force_cpu else '',
             resume=args.resume)
 
     if args.distributed:
-        model = DDP(model, delay_allreduce=True)
-        if model_ema is not None and not args.model_ema_force_cpu:
-            # must also distribute EMA model to allow validation
-            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
-            model_ema.ema_has_module = True
+        if args.sync_bn:
+            try:
+                if has_apex:
+                    model = convert_syncbn_model(model)
+                else:
+                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+                if args.local_rank == 0:
+                    logging.info('Converted model to use Synchronized BatchNorm.')
+            except Exception as e:
+                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
+        if has_apex:
+            model = DDP(model, delay_allreduce=True)
+        else:
+            if args.local_rank == 0:
+                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
+            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
+        # NOTE: EMA model does not need to be wrapped by DDP
 
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
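A minimal usage sketch of the non-Apex distributed path this change enables; the function and variable names below are placeholders, and the env:// rendezvous variables are assumed to be set by a launcher such as python -m torch.distributed.launch:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_distributed(model: nn.Module, local_rank: int, sync_bn: bool = True) -> nn.Module:
    # pin this process to its GPU and join the process group
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')
    model.cuda()
    if sync_bn:
        # native SyncBatchNorm conversion, available in PyTorch >= 1.1
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # native DDP: one replica per process, pinned to local_rank
    return DDP(model, device_ids=[local_rank])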