diff --git a/comms/torchcomms/gloo/TorchCommGloo.cpp b/comms/torchcomms/gloo/TorchCommGloo.cpp
index 00d5f80a..896d829e 100644
--- a/comms/torchcomms/gloo/TorchCommGloo.cpp
+++ b/comms/torchcomms/gloo/TorchCommGloo.cpp
@@ -278,6 +278,20 @@ void recvTensor(
   }
 }
 
+ReduceOp convertReduceOpToCPU(const ReduceOp& op) {
+  if (op.type() == ReduceOp::RedOpType::PREMUL_SUM && op.factor().has_value()) {
+    const auto& factor = *op.factor();
+    if (std::holds_alternative<at::Tensor>(factor)) {
+      const auto& tensorFactor = std::get<at::Tensor>(factor);
+      if (tensorFactor.device().type() == at::kCUDA) {
+        auto cpuFactor = tensorFactor.to(at::kCPU).contiguous().clone();
+        return ReduceOp::make_nccl_premul_sum(PreMulSumFactorT(cpuFactor));
+      }
+    }
+  }
+  return ReduceOp(op);
+}
+
 } // namespace
 
 TorchCommGloo::TorchCommGloo() : device_(at::kCPU) {}
@@ -521,7 +535,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::broadcast(
 
   tracing_->recordEventWithInputOutput("broadcast", root, {tensor}, {tensor});
 
-  // This will synchronize the stream.
   auto tensorCPU = tensor.to(at::kCPU);
 
   return createWork(
@@ -532,7 +545,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::broadcast(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::BroadcastOptions opts(context);
-       const auto& scalarType = tensor.scalar_type();
+       const auto& scalarType = tensorCPU.scalar_type();
        opts.setRoot(root);
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
@@ -562,28 +575,28 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_reduce(
 
   tracing_->recordEventWithInputOutput("all_reduce", rank_, {tensor}, {tensor});
 
-  // This will synchronize the stream.
   auto tensorCPU = tensor.to(at::kCPU);
+  auto opCPU = convertReduceOpToCPU(op);
 
   return createWork(
      [tensor,
       tensorCPU,
-      op,
+      opCPU,
       options,
       context = context_,
       tag = nextTag()]() mutable {
        gloo::AllreduceOptions opts(context);
-       const auto& scalarType = tensor.scalar_type();
-       opts.setReduceFunction(getFunction(scalarType, op));
+       const auto& scalarType = tensorCPU.scalar_type();
+       opts.setReduceFunction(getFunction(scalarType, opCPU));
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
          opts.setTimeout(options.timeout);
        }
        GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensorCPU);
 
-       preReduce(tensorCPU, op);
+       preReduce(tensorCPU, opCPU);
        gloo::allreduce(opts);
-       postReduce(tensorCPU, op);
+       postReduce(tensorCPU, opCPU);
 
        if (tensorCPU.device() != tensor.device()) {
          // This will block the CPU thread so we don't need to synchronize the
@@ -606,20 +619,20 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce(
 
   tracing_->recordEventWithInputOutput("reduce", root, {tensor}, {tensor});
 
-  // This will synchronize the stream.
   auto tensorCPU = tensor.to(at::kCPU);
+  auto opCPU = convertReduceOpToCPU(op);
 
   return createWork(
      [tensor,
       tensorCPU,
       root,
-      op,
+      opCPU,
       options,
       context = context_,
       tag = nextTag()]() mutable {
        gloo::ReduceOptions opts(context);
-       const auto& scalarType = tensor.scalar_type();
-       opts.setReduceFunction(getFunction(scalarType, op));
+       const auto& scalarType = tensorCPU.scalar_type();
+       opts.setReduceFunction(getFunction(scalarType, opCPU));
        opts.setRoot(root);
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
@@ -627,9 +640,9 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce(
        }
        GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensorCPU);
 
-       preReduce(tensorCPU, op);
+       preReduce(tensorCPU, opCPU);
        gloo::reduce(opts);
-       postReduce(tensorCPU, op);
+       postReduce(tensorCPU, opCPU);
 
        if (tensorCPU.device() != tensor.device()) {
          // This will block the CPU thread so we don't need to synchronize the
@@ -667,7 +680,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather(
   tracing_->recordEventWithInputOutput(
       "all_gather", rank_, tensor_list, {tensor});
 
-  // Convert tensors to CPU
   auto tensorCPU = tensor.to(at::kCPU);
   std::vector<at::Tensor> tensorListCPU;
   tensorListCPU.reserve(tensor_list.size());
@@ -685,7 +697,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::AllgatherOptions opts(context);
-       const auto& scalarType = tensor.scalar_type();
+       const auto& scalarType = tensorCPU.scalar_type();
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
          opts.setTimeout(options.timeout);
@@ -745,7 +757,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather_single(
   tracing_->recordEventWithInputOutput(
       "all_gather_single", rank_, {input}, {output});
 
-  // Convert tensors to CPU
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);
 
@@ -758,7 +769,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_gather_single(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::AllgatherOptions opts(context);
-       const auto& scalarType = input.scalar_type();
+       const auto& scalarType = inputCPU.scalar_type();
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
          opts.setTimeout(options.timeout);
@@ -846,12 +857,13 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(
 
   // Convert tensors to CPU (noop if already on CPU)
   auto inputCPU = input.to(at::kCPU);
+  auto opCPU = convertReduceOpToCPU(op);
 
   return createWork(
      [input,
       output,
       inputCPU,
-      op,
+      opCPU,
       options,
       rank = rank_,
       context = context_,
@@ -861,8 +873,8 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(
        // 2. Each rank takes its portion from the result
        gloo::AllreduceOptions opts(context);
-       const auto& scalarType = input.scalar_type();
-       opts.setReduceFunction(getFunction(scalarType, op));
+       const auto& scalarType = inputCPU.scalar_type();
+       opts.setReduceFunction(getFunction(scalarType, opCPU));
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
          opts.setTimeout(options.timeout);
        }
@@ -872,9 +884,9 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::reduce_scatter_single(
        GENERATE_ALL_TYPES(scalarType, setInput, opts, inputCPU);
        GENERATE_ALL_TYPES(scalarType, setOutput, opts, inputCPU);
 
-       preReduce(inputCPU, op);
+       preReduce(inputCPU, opCPU);
        gloo::allreduce(opts);
-       postReduce(inputCPU, op);
+       postReduce(inputCPU, opCPU);
 
        // Extract this rank's portion from the reduced result
        auto chunkSize = output.numel();
@@ -921,7 +933,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_single(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::AlltoallOptions opts(context);
-       const auto& scalarType = input.scalar_type();
+       const auto& scalarType = inputCPU.scalar_type();
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
          opts.setTimeout(options.timeout);
        }
@@ -986,7 +998,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_v_single(
   tracing_->recordEventWithInputOutput(
       "all_to_all_v_single", rank_, {input}, {output});
 
-  // Convert tensors to CPU (noop if already on CPU)
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);
 
@@ -1001,7 +1012,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_v_single(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::AlltoallvOptions opts(context);
-       const auto& scalarType = input.scalar_type();
+       const auto& scalarType = inputCPU.scalar_type();
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
          opts.setTimeout(options.timeout);
        }
@@ -1013,8 +1024,8 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::all_to_all_v_single(
        outputElements.reserve(output_split_sizes.size());
 
        // Calculate number of elements in each dim 0 chunk.
-       auto dim0Numel =
-           input.numel() / std::max(input.size(0), static_cast<int64_t>(1));
+       auto dim0Numel = inputCPU.numel() /
+           std::max(inputCPU.size(0), static_cast<int64_t>(1));
        for (auto size : input_split_sizes) {
          inputElements.push_back(static_cast<int64_t>(size) * dim0Numel);
        }
@@ -1171,7 +1182,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::scatter(
   tracing_->recordEventWithInputOutput(
       "scatter", root, input_tensor_list, {output_tensor});
 
-  // Convert tensors to CPU
   auto outputCPU = output_tensor.to(at::kCPU);
 
   // Only root rank needs to prepare input list
@@ -1193,7 +1203,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::scatter(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::ScatterOptions opts(context);
-       const auto& scalarType = output_tensor.scalar_type();
+       const auto& scalarType = outputCPU.scalar_type();
        opts.setRoot(root);
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
@@ -1248,7 +1258,6 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::gather(
   tracing_->recordEventWithInputOutput(
       "gather", root, {input_tensor}, output_tensor_list);
 
-  // Convert tensors to CPU
   auto inputCPU = input_tensor.to(at::kCPU);
 
   // Only root rank needs to prepare output
@@ -1279,7 +1288,7 @@ c10::intrusive_ptr<TorchWork> TorchCommGloo::gather(
       context = context_,
       tag = nextTag()]() mutable {
        gloo::GatherOptions opts(context);
-       const auto& scalarType = input_tensor.scalar_type();
+       const auto& scalarType = inputCPU.scalar_type();
        opts.setRoot(root);
        opts.setTag(tag);
        if (options.timeout != kNoTimeout) {
diff --git a/comms/torchcomms/scripts/run_tests_integration_py.sh b/comms/torchcomms/scripts/run_tests_integration_py.sh
index 1d221ca0..5a1e84eb 100755
--- a/comms/torchcomms/scripts/run_tests_integration_py.sh
+++ b/comms/torchcomms/scripts/run_tests_integration_py.sh
@@ -23,10 +23,16 @@ run_tests
 export TEST_BACKEND=ncclx
 run_tests
 
-# Gloo
+# Gloo with CPU
 export TEST_BACKEND=gloo
 export TEST_DEVICE=cpu
 export CUDA_VISIBLE_DEVICES=""
 run_tests
 unset TEST_DEVICE
 unset CUDA_VISIBLE_DEVICES
+
+# Gloo with CUDA
+export TEST_BACKEND=gloo
+export TEST_DEVICE=cuda
+run_tests
+unset TEST_DEVICE
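
Note on the new convertReduceOpToCPU helper: Gloo is a CPU-only transport, so before the CPU-side preReduce/postReduce steps can apply a PREMUL_SUM op, a premultiply factor that lives on a CUDA device has to be copied to host memory first. The standalone sketch below illustrates the conversion pattern in isolation; it assumes the factor type is std::variant<double, at::Tensor> (a stand-in for PreMulSumFactorT), and toCPUFactor is a hypothetical name used only for this example, not part of the patch.

// Standalone sketch of the factor-conversion pattern used by
// convertReduceOpToCPU in the diff above. Assumptions: the factor type
// is std::variant<double, at::Tensor>; toCPUFactor is a hypothetical name.
#include <ATen/ATen.h>
#include <variant>

using Factor = std::variant<double, at::Tensor>; // stand-in for PreMulSumFactorT

Factor toCPUFactor(const Factor& factor) {
  if (std::holds_alternative<at::Tensor>(factor)) {
    const auto& t = std::get<at::Tensor>(factor);
    if (t.device().type() == at::kCUDA) {
      // .to(at::kCPU) copies device memory to host; contiguous().clone()
      // gives the CPU-side reduce a dense buffer it owns outright.
      return t.to(at::kCPU).contiguous().clone();
    }
  }
  // Scalar factors and factors already on CPU pass through unchanged.
  return factor;
}

In the patch itself the converted factor is re-wrapped via ReduceOp::make_nccl_premul_sum, so downstream code (getFunction, preReduce, postReduce) handles the op uniformly whether the caller's factor was a CPU or CUDA tensor. The script change then adds a second Gloo leg with TEST_DEVICE=cuda and GPUs left visible, exercising these device-to-host paths end to end.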