Commit 370a221

Merge: [DLRM/PyT] Stop using apex AMP and DDP
2 parents: 9becdf8 + 5bc69ca

File tree

3 files changed, +10 -11 lines

PyTorch/Recommendation/DLRM/dlrm/cuda_ext/fused_gather_embedding.py

Lines changed: 4 additions & 2 deletions
@@ -17,7 +17,7 @@
 """
 
 from absl import logging
-from apex import amp
+import torch
 from torch.autograd import Function
 
 from dlrm.cuda_ext import fused_embedding
@@ -26,12 +26,14 @@
 class BuckleEmbeddingFusedGatherFunction(Function):
     """Customized embedding gather """
     @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, embedding, indices, offsets, amp_train):
         output = fused_embedding.gather_gpu_fused_fwd(embedding, indices, offsets, amp_train)
         ctx.save_for_backward(embedding, indices, offsets)
         return output
 
     @staticmethod
+    @torch.cuda.amp.custom_bwd
     def backward(ctx, grad_output):
         embedding, indices, offsets = ctx.saved_tensors
 
@@ -40,4 +42,4 @@ def backward(ctx, grad_output):
         return grad_weights, None, None, None
 
 
-buckle_embedding_fused_gather = amp.float_function(BuckleEmbeddingFusedGatherFunction.apply)
+buckle_embedding_fused_gather = BuckleEmbeddingFusedGatherFunction.apply
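
The removed apex wrapper (amp.float_function) kept this custom op in FP32 under apex AMP; the native replacement is to decorate forward with torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) and backward with torch.cuda.amp.custom_bwd, then expose Function.apply directly. A minimal, self-contained sketch of the same pattern using a toy gather instead of the DLRM CUDA kernel (Float32Gather and float32_gather are illustrative names, not part of the repo):

# Toy stand-in for an FP32-only custom op, decorated the same way as
# BuckleEmbeddingFusedGatherFunction after this commit.
import torch
from torch.autograd import Function

class Float32Gather(Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, weight, indices):
        # Under torch.cuda.amp.autocast, `weight` arrives here as float32
        # because of cast_inputs, even if the surrounding region runs in half.
        ctx.save_for_backward(indices)
        ctx.num_rows = weight.size(0)
        return weight[indices]

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        grad_weight = torch.zeros(ctx.num_rows, grad_output.size(-1),
                                  dtype=grad_output.dtype, device=grad_output.device)
        grad_weight.index_add_(0, indices, grad_output)
        return grad_weight, None

# Exposed the same way as buckle_embedding_fused_gather above.
float32_gather = Float32Gather.apply

if torch.cuda.is_available():
    weight = torch.randn(10, 4, device="cuda", requires_grad=True)
    idx = torch.tensor([0, 3, 7], device="cuda")
    with torch.cuda.amp.autocast():
        out = float32_gather(weight, idx)   # custom op computed in float32
    out.sum().backward()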

PyTorch/Recommendation/DLRM/dlrm/cuda_ext/sparse_embedding.py

Lines changed: 4 additions & 4 deletions
@@ -15,7 +15,7 @@
 import copy
 
 import torch
-from apex import amp
+from torch.cuda import amp
 from dlrm.cuda_ext import sparse_gather
 from torch import nn
 from torch.autograd import Function
@@ -24,18 +24,18 @@
 class EmbeddingGatherFunction(Function):
     """Customized embedding gather with fused plain SGD"""
     @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, embedding, indices):
         output = sparse_gather.gather_gpu_fwd(embedding, indices)
         ctx.save_for_backward(indices)
         ctx.num_features = embedding.size(0)
         return output
 
     @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
     def backward(ctx, grad_output):
         indices = ctx.saved_tensors[0]
-
         grad_embedding = sparse_gather.gather_gpu_bwd(grad_output, indices, ctx.num_features)
-
         return grad_embedding, None
 
 
@@ -66,4 +66,4 @@ def forward(self, categorical_inputs):
         return embedding_out
 
 
-embedding_gather = amp.float_function(EmbeddingGatherFunction.apply)
+embedding_gather = EmbeddingGatherFunction.apply
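
The new import, from torch.cuda import amp, brings in the native AMP namespace (amp.autocast, amp.GradScaler) in place of apex's amp. A generic native-AMP training step under that API, not taken from the repo's training loop (the linear model, SGD optimizer, and random data are placeholders, and a CUDA device is assumed):

# Generic native-AMP training step; model, optimizer, and data are placeholders.
import torch
from torch.cuda import amp

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = amp.GradScaler()                 # replaces apex amp.initialize / amp.scale_loss

for _ in range(3):
    x = torch.randn(32, 16, device="cuda")
    y = torch.randn(32, 1, device="cuda")
    optimizer.zero_grad()
    with amp.autocast():                  # mixed-precision forward
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()         # scaled backward pass
    scaler.step(optimizer)                # unscales grads, then optimizer.step()
    scaler.update()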

PyTorch/Recommendation/DLRM/dlrm/scripts/main.py

Lines changed: 2 additions & 5 deletions
@@ -17,7 +17,7 @@
 import os
 import sys
 from absl import app, flags, logging
-from apex import amp, parallel, optimizers as apex_optim
+from apex import optimizers as apex_optim
 
 from dlrm.data.feature_spec import FeatureSpec
 from dlrm.model.distributed import DistributedDlrm
@@ -500,10 +500,7 @@ def parallelize(model):
         if world_size <= 1:
             return model
 
-        if use_gpu:
-            model.top_model = parallel.DistributedDataParallel(model.top_model)
-        else: # Use other backend for CPU
-            model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
+        model.top_model = torch.nn.parallel.DistributedDataParallel(model.top_model)
         return model
 
     if FLAGS.mode == 'test':
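
With apex parallel.DistributedDataParallel removed, both GPU and CPU runs go through torch.nn.parallel.DistributedDataParallel; the difference is only the process-group backend (nccl vs. gloo) chosen at initialization. A minimal sketch of that setup, not the repo's parallelize() helper (it assumes a torchrun-style launch that sets RANK, WORLD_SIZE, and LOCAL_RANK):

# Minimal native-DDP setup; assumes torchrun sets RANK/WORLD_SIZE/LOCAL_RANK.
import os
import torch
import torch.distributed as dist

use_gpu = torch.cuda.is_available()
dist.init_process_group(backend="nccl" if use_gpu else "gloo")

if use_gpu:
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
else:
    device = torch.device("cpu")

model = torch.nn.Linear(16, 1).to(device)

# One DDP class for both backends, as in the simplified parallelize() above.
model = torch.nn.parallel.DistributedDataParallel(
    model,
    device_ids=[device.index] if use_gpu else None,
)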
