73 changes: 41 additions & 32 deletions comms/torchcomms/gloo/TorchCommGloo.cpp
@@ -278,6 +278,20 @@ void recvTensor(
}
}

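// Gloo reduces on CPU tensors, so a PREMUL_SUM op whose scale factor is a
// CUDA tensor needs that factor copied to CPU before the reduce function can
// consume it alongside the CPU copies of the inputs. All other ops (and CPU
// factors) are returned unchanged.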
ReduceOp convertReduceOpToCPU(const ReduceOp& op) {
if (op.type() == ReduceOp::RedOpType::PREMUL_SUM && op.factor().has_value()) {
const auto& factor = *op.factor();
if (std::holds_alternative<at::Tensor>(factor)) {
const auto& tensorFactor = std::get<at::Tensor>(factor);
if (tensorFactor.device().type() == at::kCUDA) {
auto cpuFactor = tensorFactor.to(at::kCPU).contiguous().clone();
return ReduceOp::make_nccl_premul_sum(PreMulSumFactorT(cpuFactor));
}
}
}
return ReduceOp(op);
}

} // namespace

TorchCommGloo::TorchCommGloo() : device_(at::kCPU) {}
@@ -521,7 +535,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::broadcast(

tracing_->recordEventWithInputOutput("broadcast", root, {tensor}, {tensor});

// This will synchronize the stream.
auto tensorCPU = tensor.to(at::kCPU);

return createWork(
@@ -532,7 +545,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::broadcast(
context = context_,
tag = nextTag()]() mutable {
gloo::BroadcastOptions opts(context);
const auto& scalarType = tensor.scalar_type();
const auto& scalarType = tensorCPU.scalar_type();
opts.setRoot(root);
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
@@ -562,28 +575,28 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_reduce(

tracing_->recordEventWithInputOutput("all_reduce", rank_, {tensor}, {tensor});

// This will synchronize the stream.
auto tensorCPU = tensor.to(at::kCPU);
auto opCPU = convertReduceOpToCPU(op);

return createWork(
[tensor,
tensorCPU,
op,
opCPU,
options,
context = context_,
tag = nextTag()]() mutable {
gloo::AllreduceOptions opts(context);
const auto& scalarType = tensor.scalar_type();
opts.setReduceFunction(getFunction(scalarType, op));
const auto& scalarType = tensorCPU.scalar_type();
opts.setReduceFunction(getFunction(scalarType, opCPU));
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
}
GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensorCPU);

preReduce(tensorCPU, op);
preReduce(tensorCPU, opCPU);
gloo::allreduce(opts);
postReduce(tensorCPU, op);
postReduce(tensorCPU, opCPU);

if (tensorCPU.device() != tensor.device()) {
// This will block the CPU thread so we don't need to synchronize the
@@ -606,30 +619,30 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce(

tracing_->recordEventWithInputOutput("reduce", root, {tensor}, {tensor});

// This will synchronize the stream.
auto tensorCPU = tensor.to(at::kCPU);
auto opCPU = convertReduceOpToCPU(op);

return createWork(
[tensor,
tensorCPU,
root,
op,
opCPU,
options,
context = context_,
tag = nextTag()]() mutable {
gloo::ReduceOptions opts(context);
const auto& scalarType = tensor.scalar_type();
opts.setReduceFunction(getFunction(scalarType, op));
const auto& scalarType = tensorCPU.scalar_type();
opts.setReduceFunction(getFunction(scalarType, opCPU));
opts.setRoot(root);
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
}
GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensorCPU);

preReduce(tensorCPU, op);
preReduce(tensorCPU, opCPU);
gloo::reduce(opts);
postReduce(tensorCPU, op);
postReduce(tensorCPU, opCPU);

if (tensorCPU.device() != tensor.device()) {
// This will block the CPU thread so we don't need to synchronize the
@@ -667,7 +680,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather(
tracing_->recordEventWithInputOutput(
"all_gather", rank_, tensor_list, {tensor});

// Convert tensors to CPU
auto tensorCPU = tensor.to(at::kCPU);
std::vector<at::Tensor> tensorListCPU;
tensorListCPU.reserve(tensor_list.size());
@@ -685,7 +697,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather(
context = context_,
tag = nextTag()]() mutable {
gloo::AllgatherOptions opts(context);
const auto& scalarType = tensor.scalar_type();
const auto& scalarType = tensorCPU.scalar_type();
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
@@ -745,7 +757,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather_single(
tracing_->recordEventWithInputOutput(
"all_gather_single", rank_, {input}, {output});

// Convert tensors to CPU
auto inputCPU = input.to(at::kCPU);
auto outputCPU = output.to(at::kCPU);

@@ -758,7 +769,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather_single(
context = context_,
tag = nextTag()]() mutable {
gloo::AllgatherOptions opts(context);
const auto& scalarType = input.scalar_type();
const auto& scalarType = inputCPU.scalar_type();
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
@@ -846,12 +857,13 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(

// Convert tensors to CPU (noop if already on CPU)
auto inputCPU = input.to(at::kCPU);
auto opCPU = convertReduceOpToCPU(op);

return createWork(
[input,
output,
inputCPU,
op,
opCPU,
options,
rank = rank_,
context = context_,
@@ -861,8 +873,8 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(
// 2. Each rank takes its portion from the result

gloo::AllreduceOptions opts(context);
const auto& scalarType = input.scalar_type();
opts.setReduceFunction(getFunction(scalarType, op));
const auto& scalarType = inputCPU.scalar_type();
opts.setReduceFunction(getFunction(scalarType, opCPU));
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
Expand All @@ -872,9 +884,9 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(
GENERATE_ALL_TYPES(scalarType, setInput, opts, inputCPU);
GENERATE_ALL_TYPES(scalarType, setOutput, opts, inputCPU);

preReduce(inputCPU, op);
preReduce(inputCPU, opCPU);
gloo::allreduce(opts);
postReduce(inputCPU, op);
postReduce(inputCPU, opCPU);

// Extract this rank's portion from the reduced result
auto chunkSize = output.numel();
@@ -921,7 +933,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_single(
context = context_,
tag = nextTag()]() mutable {
gloo::AlltoallOptions opts(context);
const auto& scalarType = input.scalar_type();
const auto& scalarType = inputCPU.scalar_type();
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
@@ -986,7 +998,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_v_single(
tracing_->recordEventWithInputOutput(
"all_to_all_v_single", rank_, {input}, {output});

// Convert tensors to CPU (noop if already on CPU)
auto inputCPU = input.to(at::kCPU);
auto outputCPU = output.to(at::kCPU);

@@ -1001,7 +1012,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_v_single(
context = context_,
tag = nextTag()]() mutable {
gloo::AlltoallvOptions opts(context);
const auto& scalarType = input.scalar_type();
const auto& scalarType = inputCPU.scalar_type();
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
opts.setTimeout(options.timeout);
Expand All @@ -1013,8 +1024,8 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_v_single(
outputElements.reserve(output_split_sizes.size());

// Calculate number of elements in each dim 0 chunk.
auto dim0Numel =
input.numel() / std::max(input.size(0), static_cast<int64_t>(1));
auto dim0Numel = inputCPU.numel() /
std::max(inputCPU.size(0), static_cast<int64_t>(1));
for (auto size : input_split_sizes) {
inputElements.push_back(static_cast<int64_t>(size) * dim0Numel);
}
@@ -1171,7 +1182,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::scatter(
tracing_->recordEventWithInputOutput(
"scatter", root, input_tensor_list, {output_tensor});

// Convert tensors to CPU
auto outputCPU = output_tensor.to(at::kCPU);

// Only root rank needs to prepare input list
@@ -1193,7 +1203,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::scatter(
context = context_,
tag = nextTag()]() mutable {
gloo::ScatterOptions opts(context);
const auto& scalarType = output_tensor.scalar_type();
const auto& scalarType = outputCPU.scalar_type();
opts.setRoot(root);
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
@@ -1248,7 +1258,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::gather(
tracing_->recordEventWithInputOutput(
"gather", root, {input_tensor}, output_tensor_list);

// Convert tensors to CPU
auto inputCPU = input_tensor.to(at::kCPU);

// Only root rank needs to prepare output
@@ -1279,7 +1288,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::gather(
context = context_,
tag = nextTag()]() mutable {
gloo::GatherOptions opts(context);
const auto& scalarType = input_tensor.scalar_type();
const auto& scalarType = inputCPU.scalar_type();
opts.setRoot(root);
opts.setTag(tag);
if (options.timeout != kNoTimeout) {
8 changes: 7 additions & 1 deletion comms/torchcomms/scripts/run_tests_integration_py.sh
@@ -23,10 +23,16 @@ run_tests
export TEST_BACKEND=ncclx
run_tests

# Gloo
# Gloo with CPU
export TEST_BACKEND=gloo
export TEST_DEVICE=cpu
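# Hide any GPUs so the Gloo CPU path is exercised even on GPU machines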
export CUDA_VISIBLE_DEVICES=""
run_tests
unset TEST_DEVICE
unset CUDA_VISIBLE_DEVICES

# Gloo with CUDA
export TEST_BACKEND=gloo
export TEST_DEVICE=cuda
run_tests
unset TEST_DEVICE