Commit 2a7c251

Merge: [NCF/PyT] Stop using deprecated apex AMP and apex DDP
2 parents: 370a221 + 05ee986

File tree

2 files changed (+10, -30 lines)

PyTorch/Recommendation/NCF/README.md

Lines changed: 1 addition & 13 deletions
````diff
@@ -143,23 +143,11 @@ The ability to train deep learning networks with lower precision was introduced
 For information about:
 - How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) documentation.
 - Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
-- APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
 
 
 #### Enabling mixed precision
 
-Using the Automatic Mixed Precision (AMP) package requires two modifications in the source code.
-The first one is to initialize the model and the optimizer using the `amp.initialize` function:
-```python
-model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
-                                  keep_batchnorm_fp32=False, loss_scale='dynamic')
-```
-
-The second one is to use AMP's loss scaling context manager:
-```python
-with amp.scale_loss(loss, optimizer) as scaled_loss:
-    scaled_loss.backward()
-```
+Mixed precision training is turned off by default. To turn it on, issue the `--amp` flag to the `main.py` script.
 
 #### Enabling TF32
 
````
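With apex removed, enabling mixed precision comes down to a single command-line switch. As a rough illustration of how such a flag typically feeds PyTorch-native AMP, here is a minimal sketch; the argparse wiring is an assumption made for illustration, and only the `--amp` name and the `enabled=args.amp` idiom appear in the actual diff to `ncf.py` below.

```python
# Hypothetical flag plumbing for `--amp` (the argparse setup is assumed, not
# shown in this commit); native AMP is toggled purely through `enabled=`.
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--amp", action="store_true",
                    help="enable automatic mixed precision training")
args = parser.parse_args()

# Both constructs degrade to no-ops when enabled=False, so one code path
# covers FP32 and mixed-precision training.
scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
with torch.cuda.amp.autocast(enabled=args.amp):
    pass  # forward pass and loss computation go here
```

Because both constructs are no-ops when `enabled=False`, FP32 and mixed-precision training share one code path, which is what lets the commit delete the `if args.amp:` branches in `ncf.py`.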

PyTorch/Recommendation/NCF/ncf.py

Lines changed: 9 additions & 17 deletions
```diff
@@ -47,9 +47,6 @@
 
 import dllogger
 
-from apex.parallel import DistributedDataParallel as DDP
-from apex import amp
-
 
 def synchronized_timestamp():
     torch.cuda.synchronize()
@@ -252,12 +249,8 @@ def main():
     model = model.cuda()
     criterion = criterion.cuda()
 
-    if args.amp:
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O2",
-                                          keep_batchnorm_fp32=False, loss_scale='dynamic')
-
     if args.distributed:
-        model = DDP(model)
+        model = torch.nn.parallel.DistributedDataParallel(model)
 
     local_batch = args.batch_size // args.world_size
     traced_criterion = torch.jit.trace(criterion.forward,
```
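Swapping `apex.parallel.DistributedDataParallel` for the built-in `torch.nn.parallel.DistributedDataParallel` assumes a default process group has already been initialized; that setup is outside this hunk. Below is a minimal sketch of the native setup, where the NCCL backend, environment-variable rendezvous, and `device_ids` argument are assumptions for illustration rather than code from this repository.

```python
# Minimal native-DDP sketch (assumptions: NCCL backend, one GPU per process,
# env:// rendezvous); the commit itself only swaps the wrapper class.
import os
import torch

def distributed_setup(model: torch.nn.Module) -> torch.nn.Module:
    # Rank and world size are expected to come from the launcher via env vars.
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    model = model.cuda()
    # The built-in wrapper replaces apex DDP; gradient all-reduce is handled
    # automatically during backward().
    return torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```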
@@ -291,6 +284,7 @@ def main():
291284
best_epoch = 0
292285
best_model_timestamp = synchronized_timestamp()
293286
train_throughputs, eval_throughputs = [], []
287+
scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
294288

295289
for epoch in range(args.epochs):
296290

@@ -311,16 +305,14 @@ def main():
311305
label_features = batch_dict[LABEL_CHANNEL_NAME]
312306
label_batch = label_features[label_feature_name]
313307

314-
outputs = model(user_batch, item_batch)
315-
loss = traced_criterion(outputs, label_batch.view(-1, 1)).float()
316-
loss = torch.mean(loss.view(-1), 0)
308+
with torch.cuda.amp.autocast(enabled=args.amp):
309+
outputs = model(user_batch, item_batch)
310+
loss = traced_criterion(outputs, label_batch.view(-1, 1))
311+
loss = torch.mean(loss.float().view(-1), 0)
317312

318-
if args.amp:
319-
with amp.scale_loss(loss, optimizer) as scaled_loss:
320-
scaled_loss.backward()
321-
else:
322-
loss.backward()
323-
optimizer.step()
313+
scaler.scale(loss).backward()
314+
scaler.step(optimizer)
315+
scaler.update()
324316

325317
for p in model.parameters():
326318
p.grad = None
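Taken together, the `ncf.py` changes follow the standard native-AMP recipe: one `GradScaler`, the forward pass and loss under `autocast`, and scale, step, update around the optimizer. Below is a self-contained sketch of that recipe with a toy model and random data standing in for NeuMF and the NCF data pipeline; it illustrates the pattern and is not code from this commit.

```python
# Self-contained sketch of the native AMP pattern adopted by this commit.
# The model, data, and optimizer here are placeholders, not NCF code.
import torch

use_amp = torch.cuda.is_available()  # stands in for args.amp
device = "cuda" if use_amp else "cpu"

model = torch.nn.Linear(64, 1).to(device)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(10):
    inputs = torch.randn(128, 64, device=device)
    labels = torch.randint(0, 2, (128, 1), device=device).float()

    # Forward pass and loss under autocast: eligible ops run in FP16,
    # numerically sensitive ones stay in FP32.
    with torch.cuda.amp.autocast(enabled=use_amp):
        outputs = model(inputs)
        loss = criterion(outputs, labels)

    # Scale the loss to avoid FP16 gradient underflow; scaler.step() unscales
    # before the optimizer update, and scaler.update() adapts the scale factor
    # dynamically (the role apex's loss_scale='dynamic' used to play).
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Match the diff's gradient reset style.
    for p in model.parameters():
        p.grad = None
```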
