
Commit b20bb58

Distributed tweaks

* Support PyTorch native DDP as a fallback if APEX is not present
* Support SyncBN for both APEX and Torch native (if torch >= 1.1)
* EMA model does not appear to need a DDP wrapper; it has no gradients and is updated from the wrapped model
1 parent 6fc886a commit b20bb58
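
The fallback described by the first two bullets can be sketched as below. This is an illustrative example of the pattern (try Apex first, otherwise fall back to the PyTorch-native DistributedDataParallel and SyncBatchNorm), not the exact train.py code; the wrap_distributed helper and its arguments are placeholders.

import torch

try:
    # Prefer NVIDIA Apex DDP / SyncBN when it is installed
    from apex.parallel import DistributedDataParallel as DDP
    from apex.parallel import convert_syncbn_model
    has_apex = True
except ImportError:
    # Fall back to the PyTorch-native implementations (torch >= 1.1)
    from torch.nn.parallel import DistributedDataParallel as DDP
    has_apex = False


def wrap_distributed(model, local_rank, sync_bn=True):
    # Assumes torch.distributed.init_process_group() has already been
    # called by the launch script (e.g. torch.distributed.launch).
    if sync_bn:
        if has_apex:
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda()
    if has_apex:
        # Apex DDP infers the device from the module's parameters
        model = DDP(model, delay_allreduce=True)
    else:
        # Native DDP needs the local device id
        model = DDP(model, device_ids=[local_rank])
    return model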

1 file changed (+26, -13)

train.py: 26 additions & 13 deletions
@@ -10,6 +10,7 @@
     from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
+    from torch.nn.parallel import DistributedDataParallel as DDP
     has_apex = False
 
 from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
@@ -169,8 +170,9 @@ def main():
         bn_eps=args.bn_eps,
         checkpoint_path=args.initial_checkpoint)
 
-    logging.info('Model %s created, param count: %d' %
-                 (args.model, sum([m.numel() for m in model.parameters()])))
+    if args.local_rank == 0:
+        logging.info('Model %s created, param count: %d' %
+                     (args.model, sum([m.numel() for m in model.parameters()])))
 
     data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
 
@@ -187,36 +189,47 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
-        if args.distributed and args.sync_bn and has_apex:
-            model = convert_syncbn_model(model)
         model.cuda()
 
     optimizer = create_optimizer(args, model)
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)
 
+    use_amp = False
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
-        logging.info('AMP enabled')
-    else:
-        use_amp = False
-        logging.info('AMP disabled')
+    if args.local_rank == 0:
+        logging.info('NVIDIA APEX {}. AMP {}.'.format(
+            'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
 
     model_ema = None
     if args.model_ema:
+        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
         model_ema = ModelEma(
             model,
             decay=args.model_ema_decay,
             device='cpu' if args.model_ema_force_cpu else '',
             resume=args.resume)
 
     if args.distributed:
-        model = DDP(model, delay_allreduce=True)
-        if model_ema is not None and not args.model_ema_force_cpu:
-            # must also distribute EMA model to allow validation
-            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
-            model_ema.ema_has_module = True
+        if args.sync_bn:
+            try:
+                if has_apex:
+                    model = convert_syncbn_model(model)
+                else:
+                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+                if args.local_rank == 0:
+                    logging.info('Converted model to use Synchronized BatchNorm.')
+            except Exception as e:
+                logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
+        if has_apex:
+            model = DDP(model, delay_allreduce=True)
+        else:
+            if args.local_rank == 0:
+                logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
+            model = DDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
+        # NOTE: EMA model does not need to be wrapped by DDP
 
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
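
The third bullet and the NOTE added at the end of the distributed block both make the same point: the EMA copy never needs its own DDP wrapper, since it carries no gradients and is simply refreshed from the (possibly DDP-wrapped) training model after each optimizer step. A minimal sketch of that idea, using a hypothetical SimpleEma helper rather than timm's actual ModelEma:

import copy
import torch


class SimpleEma:
    # Keeps a gradient-free shadow copy of the model; never wrapped in DDP.
    def __init__(self, model, decay=0.9998):
        self.ema = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Refresh the shadow weights from the training model. The 'module.'
        # prefix added by a DDP/DataParallel wrapper is stripped so the
        # wrapped model's state maps onto the unwrapped EMA copy.
        msd = model.state_dict()
        with torch.no_grad():
            for k, ema_v in self.ema.state_dict().items():
                mk = 'module.' + k if 'module.' + k in msd else k
                model_v = msd[mk].detach()
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)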
