From bd2f9699b4dd8a76889dd2974f3c8537eec12243 Mon Sep 17 00:00:00 2001
From: arkadip-maitra
Date: Thu, 13 Nov 2025 18:22:55 +0530
Subject: [PATCH 1/4] support for testing gloo with cuda tensors

---
 comms/torchcomms/gloo/TorchCommGloo.cpp      | 158 +++++++++---------
 .../scripts/run_tests_integration_py.sh      |   8 +-
 2 files changed, 83 insertions(+), 83 deletions(-)

diff --git a/comms/torchcomms/gloo/TorchCommGloo.cpp b/comms/torchcomms/gloo/TorchCommGloo.cpp
index 00d5f80a..472debee 100644
--- a/comms/torchcomms/gloo/TorchCommGloo.cpp
+++ b/comms/torchcomms/gloo/TorchCommGloo.cpp
@@ -278,6 +278,20 @@ void recvTensor(
   }
 }

+ReduceOp convertReduceOpToCPU(const ReduceOp& op) {
+  if (op.type() == ReduceOp::RedOpType::PREMUL_SUM && op.factor().has_value()) {
+    const auto& factor = *op.factor();
+    if (std::holds_alternative<at::Tensor>(factor)) {
+      const auto& tensorFactor = std::get<at::Tensor>(factor);
+      if (tensorFactor.device().type() == at::kCUDA) {
+        auto cpuFactor = tensorFactor.to(at::kCPU).contiguous().clone();
+        return ReduceOp::make_nccl_premul_sum(PreMulSumFactorT(cpuFactor));
+      }
+    }
+  }
+  return ReduceOp(op);
+}
+
 } // namespace

 TorchCommGloo::TorchCommGloo() : device_(at::kCPU) {}
@@ -443,12 +457,15 @@ c10::intrusive_ptr TorchCommGloo::recv(

   tracing_->recordEventWithInputOutput("recv", src, {tensor}, {tensor});

+  auto originalDevice = tensor.device();
+
   // Convert tensor to CPU for Gloo compatibility
   auto tensorCPU = tensor.to(at::kCPU);

   return createWork(
       [tensor,
        tensorCPU,
+       originalDevice,
        src,
        options,
        context = context_,
@@ -465,10 +482,7 @@ c10::intrusive_ptr TorchCommGloo::recv(
             tag,
             options.timeout);

-        if (tensorCPU.device() != tensor.device()) {
-          // Copy back to original device if needed
-          tensor.copy_(tensorCPU);
-        }
+        tensor.copy_(tensorCPU);
       },
       async_op);
 }
@@ -521,18 +535,19 @@ c10::intrusive_ptr TorchCommGloo::broadcast(

   tracing_->recordEventWithInputOutput("broadcast", root, {tensor}, {tensor});

-  // This will synchronize the stream.
+  auto originalDevice = tensor.device();
   auto tensorCPU = tensor.to(at::kCPU);

   return createWork(
       [tensor,
        tensorCPU,
+       originalDevice,
        root,
        options,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::BroadcastOptions opts(context);
-        const auto& scalarType = tensor.scalar_type();
+        const auto& scalarType = tensorCPU.scalar_type();
         opts.setRoot(root);
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
@@ -542,11 +557,7 @@ c10::intrusive_ptr TorchCommGloo::broadcast(

         gloo::broadcast(opts);

-        if (tensorCPU.device() != tensor.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          tensor.copy_(tensorCPU);
-        }
+        tensor.copy_(tensorCPU);
       },
       async_op);
 }
@@ -562,34 +573,32 @@ c10::intrusive_ptr TorchCommGloo::all_reduce(

   tracing_->recordEventWithInputOutput("all_reduce", rank_, {tensor}, {tensor});

-  // This will synchronize the stream.
-  auto tensorCPU = tensor.to(at::kCPU);
+  auto originalDevice = tensor.device();
+  auto tensorCPU = tensor.to(at::kCPU).contiguous().clone();
+  auto opCPU = convertReduceOpToCPU(op);

   return createWork(
       [tensor,
        tensorCPU,
-       op,
+       originalDevice,
+       opCPU,
        options,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::AllreduceOptions opts(context);
-        const auto& scalarType = tensor.scalar_type();
-        opts.setReduceFunction(getFunction(scalarType, op));
+        const auto& scalarType = tensorCPU.scalar_type();
+        opts.setReduceFunction(getFunction(scalarType, opCPU));
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
           opts.setTimeout(options.timeout);
         }
         GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensorCPU);

-        preReduce(tensorCPU, op);
+        preReduce(tensorCPU, opCPU);
         gloo::allreduce(opts);
-        postReduce(tensorCPU, op);
+        postReduce(tensorCPU, opCPU);

-        if (tensorCPU.device() != tensor.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          tensor.copy_(tensorCPU);
-        }
+        tensor.copy_(tensorCPU);
       },
       async_op);
 }
@@ -606,20 +615,22 @@ c10::intrusive_ptr TorchCommGloo::reduce(

   tracing_->recordEventWithInputOutput("reduce", root, {tensor}, {tensor});

-  // This will synchronize the stream.
-  auto tensorCPU = tensor.to(at::kCPU);
+  auto originalDevice = tensor.device();
+  auto tensorCPU = tensor.to(at::kCPU).contiguous().clone();
+  auto opCPU = convertReduceOpToCPU(op);

   return createWork(
       [tensor,
        tensorCPU,
+       originalDevice,
        root,
-       op,
+       opCPU,
        options,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::ReduceOptions opts(context);
-        const auto& scalarType = tensor.scalar_type();
-        opts.setReduceFunction(getFunction(scalarType, op));
+        const auto& scalarType = tensorCPU.scalar_type();
+        opts.setReduceFunction(getFunction(scalarType, opCPU));
         opts.setRoot(root);
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
@@ -627,15 +638,11 @@ c10::intrusive_ptr TorchCommGloo::reduce(
         }
         GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensorCPU);

-        preReduce(tensorCPU, op);
+        preReduce(tensorCPU, opCPU);
         gloo::reduce(opts);
-        postReduce(tensorCPU, op);
+        postReduce(tensorCPU, opCPU);

-        if (tensorCPU.device() != tensor.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          tensor.copy_(tensorCPU);
-        }
+        tensor.copy_(tensorCPU);
       },
       async_op);
 }
@@ -667,7 +674,7 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
   tracing_->recordEventWithInputOutput(
       "all_gather", rank_, tensor_list, {tensor});

-  // Convert tensors to CPU
+  auto originalDevice = tensor.device();
   auto tensorCPU = tensor.to(at::kCPU);
   std::vector<at::Tensor> tensorListCPU;
   tensorListCPU.reserve(tensor_list.size());
@@ -680,12 +687,13 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
        tensor_list,
        tensorCPU,
        tensorListCPU,
+       originalDevice,
        options,
        size = comm_size_,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::AllgatherOptions opts(context);
-        const auto& scalarType = tensor.scalar_type();
+        const auto& scalarType = tensorCPU.scalar_type();
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
           opts.setTimeout(options.timeout);
@@ -709,11 +717,9 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
           tensorListCPU[i].copy_(chunk);
         }

-        // Copy results back to original device if needed
+        // Copy results back to original tensors
         for (size_t i = 0; i < tensorListCPU.size(); ++i) {
-          if (tensorListCPU[i].device() != tensor_list[i].device()) {
-            tensor_list[i].copy_(tensorListCPU[i]);
-          }
+          tensor_list[i].copy_(tensorListCPU[i]);
         }
       },
       async_op);
@@ -745,7 +751,7 @@ c10::intrusive_ptr TorchCommGloo::all_gather_single(
   tracing_->recordEventWithInputOutput(
       "all_gather_single", rank_, {input}, {output});

-  // Convert tensors to CPU
+  auto originalDevice = output.device();
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);

@@ -754,11 +760,12 @@ c10::intrusive_ptr TorchCommGloo::all_gather_single(
        output,
        inputCPU,
        outputCPU,
+       originalDevice,
        options,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::AllgatherOptions opts(context);
-        const auto& scalarType = input.scalar_type();
+        const auto& scalarType = inputCPU.scalar_type();
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
           opts.setTimeout(options.timeout);
@@ -770,11 +777,7 @@ c10::intrusive_ptr TorchCommGloo::all_gather_single(

         gloo::allgather(opts);

-        if (outputCPU.device() != output.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          output.copy_(outputCPU);
-        }
+        output.copy_(outputCPU);
       },
       async_op);
 }
@@ -845,13 +848,14 @@ c10::intrusive_ptr TorchCommGloo::reduce_scatter_single(
       "reduce_scatter_single", rank_, {input}, {output});

   // Convert tensors to CPU (noop if already on CPU)
-  auto inputCPU = input.to(at::kCPU);
+  auto inputCPU = input.to(at::kCPU).contiguous().clone();
+  auto opCPU = convertReduceOpToCPU(op);

   return createWork(
       [input,
        output,
        inputCPU,
-       op,
+       opCPU,
        options,
        rank = rank_,
        context = context_,
@@ -861,8 +865,8 @@ c10::intrusive_ptr TorchCommGloo::reduce_scatter_single(
         // 1. Allreduce the full input
         // 2. Each rank takes its portion from the result
         gloo::AllreduceOptions opts(context);
-        const auto& scalarType = input.scalar_type();
-        opts.setReduceFunction(getFunction(scalarType, op));
+        const auto& scalarType = inputCPU.scalar_type();
+        opts.setReduceFunction(getFunction(scalarType, opCPU));
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
           opts.setTimeout(options.timeout);
@@ -872,9 +876,9 @@ c10::intrusive_ptr TorchCommGloo::reduce_scatter_single(
         GENERATE_ALL_TYPES(scalarType, setInput, opts, inputCPU);
         GENERATE_ALL_TYPES(scalarType, setOutput, opts, inputCPU);

-        preReduce(inputCPU, op);
+        preReduce(inputCPU, opCPU);
         gloo::allreduce(opts);
-        postReduce(inputCPU, op);
+        postReduce(inputCPU, opCPU);

         // Extract this rank's portion from the reduced result
         auto chunkSize = output.numel();
@@ -908,7 +912,7 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_single(
   tracing_->recordEventWithInputOutput(
       "all_to_all_single", rank_, {input}, {output});

-  // Convert tensors to CPU
+  auto originalDevice = output.device();
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);

@@ -917,11 +921,12 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_single(
        output,
        inputCPU,
        outputCPU,
+       originalDevice,
        options,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::AlltoallOptions opts(context);
-        const auto& scalarType = input.scalar_type();
+        const auto& scalarType = inputCPU.scalar_type();
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
           opts.setTimeout(options.timeout);
@@ -933,11 +938,7 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_single(

         gloo::alltoall(opts);

-        if (outputCPU.device() != output.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          output.copy_(outputCPU);
-        }
+        output.copy_(outputCPU);
       },
       async_op);
 }
@@ -986,7 +987,7 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(
   tracing_->recordEventWithInputOutput(
       "all_to_all_v_single", rank_, {input}, {output});

-  // Convert tensors to CPU (noop if already on CPU)
+  auto originalDevice = output.device();
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);

@@ -995,13 +996,14 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(
        output,
        inputCPU,
        outputCPU,
+       originalDevice,
        input_split_sizes,
        output_split_sizes,
        options,
        context = context_,
        tag = nextTag()]() mutable {
         gloo::AlltoallvOptions opts(context);
-        const auto& scalarType = input.scalar_type();
+        const auto& scalarType = inputCPU.scalar_type();
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
           opts.setTimeout(options.timeout);
@@ -1014,7 +1016,7 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(

         // Calculate number of elements in each dim 0 chunk.
         auto dim0Numel =
-            input.numel() / std::max(input.size(0), static_cast<int64_t>(1));
+            inputCPU.numel() / std::max(inputCPU.size(0), static_cast<int64_t>(1));
         for (auto size : input_split_sizes) {
           inputElements.push_back(static_cast<int64_t>(size) * dim0Numel);
         }
@@ -1030,11 +1032,7 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(

         gloo::alltoallv(opts);

-        if (outputCPU.device() != output.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          output.copy_(outputCPU);
-        }
+        output.copy_(outputCPU);
       },
       async_op);
 }
@@ -1171,7 +1169,7 @@ c10::intrusive_ptr TorchCommGloo::scatter(
   tracing_->recordEventWithInputOutput(
       "scatter", root, input_tensor_list, {output_tensor});

-  // Convert tensors to CPU
+  auto originalDevice = output_tensor.device();
   auto outputCPU = output_tensor.to(at::kCPU);

   // Only root rank needs to prepare input list
@@ -1186,6 +1184,7 @@ c10::intrusive_ptr TorchCommGloo::scatter(
   return createWork(
       [output_tensor,
        outputCPU,
+       originalDevice,
        inputListCPU,
        root,
        options,
@@ -1193,7 +1192,7 @@ c10::intrusive_ptr TorchCommGloo::scatter(
        context = context_,
        tag = nextTag()]() mutable {
         gloo::ScatterOptions opts(context);
-        const auto& scalarType = output_tensor.scalar_type();
+        const auto& scalarType = outputCPU.scalar_type();
         opts.setRoot(root);
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
@@ -1210,11 +1209,7 @@ c10::intrusive_ptr TorchCommGloo::scatter(

         gloo::scatter(opts);

-        if (outputCPU.device() != output_tensor.device()) {
-          // This will block the CPU thread so we don't need to synchronize the
-          // streams.
-          output_tensor.copy_(outputCPU);
-        }
+        output_tensor.copy_(outputCPU);
       },
       async_op);
 }
@@ -1248,7 +1243,7 @@ c10::intrusive_ptr TorchCommGloo::gather(
   tracing_->recordEventWithInputOutput(
       "gather", root, {input_tensor}, output_tensor_list);

-  // Convert tensors to CPU
+  auto originalDevice = input_tensor.device();
   auto inputCPU = input_tensor.to(at::kCPU);

   // Only root rank needs to prepare output
@@ -1272,6 +1267,7 @@ c10::intrusive_ptr TorchCommGloo::gather(
        inputCPU,
        outputConcatCPU,
        outputListCPU,
+       originalDevice,
        root,
        options,
        rank = rank_,
@@ -1279,7 +1275,7 @@ c10::intrusive_ptr TorchCommGloo::gather(
        context = context_,
        tag = nextTag()]() mutable {
         gloo::GatherOptions opts(context);
-        const auto& scalarType = input_tensor.scalar_type();
+        const auto& scalarType = inputCPU.scalar_type();
         opts.setRoot(root);
         opts.setTag(tag);
         if (options.timeout != kNoTimeout) {
@@ -1305,11 +1301,9 @@ c10::intrusive_ptr TorchCommGloo::gather(
             outputListCPU[i].copy_(chunk);
           }

-          // Copy results back to original device if needed
+          // Copy results back to original tensors
           for (size_t i = 0; i < outputListCPU.size(); ++i) {
-            if (outputListCPU[i].device() != output_tensor_list[i].device()) {
-              output_tensor_list[i].copy_(outputListCPU[i]);
-            }
+            output_tensor_list[i].copy_(outputListCPU[i]);
           }
         }
       },
diff --git a/comms/torchcomms/scripts/run_tests_integration_py.sh b/comms/torchcomms/scripts/run_tests_integration_py.sh
index 1d221ca0..5a1e84eb 100755
--- a/comms/torchcomms/scripts/run_tests_integration_py.sh
+++ b/comms/torchcomms/scripts/run_tests_integration_py.sh
@@ -23,10 +23,16 @@ run_tests
 export TEST_BACKEND=ncclx
 run_tests

-# Gloo
+# Gloo with CPU
 export TEST_BACKEND=gloo
 export TEST_DEVICE=cpu
 export CUDA_VISIBLE_DEVICES=""
 run_tests
 unset TEST_DEVICE
 unset CUDA_VISIBLE_DEVICES
+
+# Gloo with CUDA
+export TEST_BACKEND=gloo
+export TEST_DEVICE=cuda
+run_tests
+unset TEST_DEVICE

From 20eb5792df75114e5321570bff8749ac9edcc885 Mon Sep 17 00:00:00 2001
From: arkadip-maitra
Date: Sat, 15 Nov 2025 22:55:16 +0530
Subject: [PATCH 2/4] removing unused variable and unnecessary copy

---
 comms/torchcomms/gloo/TorchCommGloo.cpp | 79 ++++++++++++++-----------
 1 file changed, 45 insertions(+), 34 deletions(-)

diff --git a/comms/torchcomms/gloo/TorchCommGloo.cpp b/comms/torchcomms/gloo/TorchCommGloo.cpp
index 472debee..ba5d1a40 100644
--- a/comms/torchcomms/gloo/TorchCommGloo.cpp
+++ b/comms/torchcomms/gloo/TorchCommGloo.cpp
@@ -457,15 +457,12 @@ c10::intrusive_ptr TorchCommGloo::recv(

   tracing_->recordEventWithInputOutput("recv", src, {tensor}, {tensor});

-  auto originalDevice = tensor.device();
-
   // Convert tensor to CPU for Gloo compatibility
   auto tensorCPU = tensor.to(at::kCPU);

   return createWork(
       [tensor,
        tensorCPU,
-       originalDevice,
        src,
        options,
        context = context_,
@@ -482,7 +479,10 @@ c10::intrusive_ptr TorchCommGloo::recv(
             tag,
             options.timeout);

-        tensor.copy_(tensorCPU);
+        if (tensorCPU.device() != tensor.device()) {
+          // Copy back to original device if needed
+          tensor.copy_(tensorCPU);
+        }
       },
       async_op);
 }
@@ -535,13 +535,11 @@ c10::intrusive_ptr TorchCommGloo::broadcast(

   tracing_->recordEventWithInputOutput("broadcast", root, {tensor}, {tensor});

-  auto originalDevice = tensor.device();
   auto tensorCPU = tensor.to(at::kCPU);

   return createWork(
       [tensor,
        tensorCPU,
-       originalDevice,
        root,
        options,
        context = context_,
@@ -557,7 +555,11 @@ c10::intrusive_ptr TorchCommGloo::broadcast(

         gloo::broadcast(opts);

-        tensor.copy_(tensorCPU);
+        if (tensorCPU.device() != tensor.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          tensor.copy_(tensorCPU);
+        }
       },
       async_op);
 }
@@ -573,14 +575,12 @@ c10::intrusive_ptr TorchCommGloo::all_reduce(

   tracing_->recordEventWithInputOutput("all_reduce", rank_, {tensor}, {tensor});

-  auto originalDevice = tensor.device();
-  auto tensorCPU = tensor.to(at::kCPU).contiguous().clone();
+  auto tensorCPU = tensor.to(at::kCPU);
   auto opCPU = convertReduceOpToCPU(op);

   return createWork(
       [tensor,
        tensorCPU,
-       originalDevice,
        opCPU,
        options,
        context = context_,
@@ -598,7 +598,11 @@ c10::intrusive_ptr TorchCommGloo::all_reduce(
         gloo::allreduce(opts);
         postReduce(tensorCPU, opCPU);

-        tensor.copy_(tensorCPU);
+        if (tensorCPU.device() != tensor.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          tensor.copy_(tensorCPU);
+        }
       },
       async_op);
 }
@@ -615,14 +619,12 @@ c10::intrusive_ptr TorchCommGloo::reduce(

   tracing_->recordEventWithInputOutput("reduce", root, {tensor}, {tensor});

-  auto originalDevice = tensor.device();
-  auto tensorCPU = tensor.to(at::kCPU).contiguous().clone();
+  auto tensorCPU = tensor.to(at::kCPU);
   auto opCPU = convertReduceOpToCPU(op);

   return createWork(
       [tensor,
        tensorCPU,
-       originalDevice,
        root,
        opCPU,
        options,
@@ -642,7 +644,11 @@ c10::intrusive_ptr TorchCommGloo::reduce(
         gloo::reduce(opts);
         postReduce(tensorCPU, opCPU);

-        tensor.copy_(tensorCPU);
+        if (tensorCPU.device() != tensor.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          tensor.copy_(tensorCPU);
+        }
       },
       async_op);
 }
@@ -674,7 +680,6 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
   tracing_->recordEventWithInputOutput(
       "all_gather", rank_, tensor_list, {tensor});

-  auto originalDevice = tensor.device();
   auto tensorCPU = tensor.to(at::kCPU);
   std::vector<at::Tensor> tensorListCPU;
   tensorListCPU.reserve(tensor_list.size());
@@ -687,7 +692,6 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
        tensor_list,
        tensorCPU,
        tensorListCPU,
-       originalDevice,
        options,
        size = comm_size_,
        context = context_,
@@ -717,7 +721,7 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
           tensorListCPU[i].copy_(chunk);
         }

-        // Copy results back to original tensors
+        // Copy results back to original tensors (works for both CPU and CUDA)
         for (size_t i = 0; i < tensorListCPU.size(); ++i) {
           tensor_list[i].copy_(tensorListCPU[i]);
         }
@@ -751,7 +755,6 @@ c10::intrusive_ptr TorchCommGloo::all_gather_single(
   tracing_->recordEventWithInputOutput(
       "all_gather_single", rank_, {input}, {output});

-  auto originalDevice = output.device();
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);

@@ -760,7 +763,6 @@ c10::intrusive_ptr TorchCommGloo::all_gather_single(
        output,
        inputCPU,
        outputCPU,
-       originalDevice,
        options,
        context = context_,
        tag = nextTag()]() mutable {
@@ -777,7 +779,11 @@ c10::intrusive_ptr TorchCommGloo::all_gather_single(

         gloo::allgather(opts);

-        output.copy_(outputCPU);
+        if (outputCPU.device() != output.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          output.copy_(outputCPU);
+        }
       },
       async_op);
 }
@@ -848,7 +854,7 @@ c10::intrusive_ptr TorchCommGloo::reduce_scatter_single(
       "reduce_scatter_single", rank_, {input}, {output});

   // Convert tensors to CPU (noop if already on CPU)
-  auto inputCPU = input.to(at::kCPU).contiguous().clone();
+  auto inputCPU = input.to(at::kCPU);
   auto opCPU = convertReduceOpToCPU(op);

   return createWork(
@@ -912,7 +918,7 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_single(
   tracing_->recordEventWithInputOutput(
       "all_to_all_single", rank_, {input}, {output});

-  auto originalDevice = output.device();
+  // Convert tensors to CPU
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);

@@ -921,7 +927,6 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_single(
        output,
        inputCPU,
        outputCPU,
-       originalDevice,
        options,
        context = context_,
        tag = nextTag()]() mutable {
@@ -938,7 +943,11 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_single(

         gloo::alltoall(opts);

-        output.copy_(outputCPU);
+        if (outputCPU.device() != output.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          output.copy_(outputCPU);
+        }
       },
       async_op);
 }
@@ -987,7 +996,6 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(
   tracing_->recordEventWithInputOutput(
       "all_to_all_v_single", rank_, {input}, {output});

-  auto originalDevice = output.device();
   auto inputCPU = input.to(at::kCPU);
   auto outputCPU = output.to(at::kCPU);

@@ -996,7 +1004,6 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(
        output,
        inputCPU,
        outputCPU,
-       originalDevice,
        input_split_sizes,
        output_split_sizes,
        options,
@@ -1032,7 +1039,11 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(

         gloo::alltoallv(opts);

-        output.copy_(outputCPU);
+        if (outputCPU.device() != output.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          output.copy_(outputCPU);
+        }
       },
       async_op);
 }
@@ -1169,7 +1180,6 @@ c10::intrusive_ptr TorchCommGloo::scatter(
   tracing_->recordEventWithInputOutput(
       "scatter", root, input_tensor_list, {output_tensor});

-  auto originalDevice = output_tensor.device();
   auto outputCPU = output_tensor.to(at::kCPU);

   // Only root rank needs to prepare input list
@@ -1184,7 +1194,6 @@ c10::intrusive_ptr TorchCommGloo::scatter(
   return createWork(
       [output_tensor,
        outputCPU,
-       originalDevice,
        inputListCPU,
        root,
        options,
@@ -1209,7 +1218,11 @@ c10::intrusive_ptr TorchCommGloo::scatter(

         gloo::scatter(opts);

-        output_tensor.copy_(outputCPU);
+        if (outputCPU.device() != output_tensor.device()) {
+          // This will block the CPU thread so we don't need to synchronize the
+          // streams.
+          output_tensor.copy_(outputCPU);
+        }
       },
       async_op);
 }
@@ -1243,7 +1256,6 @@ c10::intrusive_ptr TorchCommGloo::gather(
   tracing_->recordEventWithInputOutput(
       "gather", root, {input_tensor}, output_tensor_list);

-  auto originalDevice = input_tensor.device();
   auto inputCPU = input_tensor.to(at::kCPU);

   // Only root rank needs to prepare output
@@ -1267,7 +1279,6 @@ c10::intrusive_ptr TorchCommGloo::gather(
        inputCPU,
        outputConcatCPU,
        outputListCPU,
-       originalDevice,
        root,
        options,
        rank = rank_,
@@ -1301,7 +1312,7 @@ c10::intrusive_ptr TorchCommGloo::gather(
             outputListCPU[i].copy_(chunk);
           }

-          // Copy results back to original tensors
+          // Copy results back to original tensors (works for both CPU and CUDA)
           for (size_t i = 0; i < outputListCPU.size(); ++i) {
             output_tensor_list[i].copy_(outputListCPU[i]);
           }

From eeb6e9296a24874bccd72da6bb88c96d4890833d Mon Sep 17 00:00:00 2001
From: arkadip-maitra
Date: Sat, 15 Nov 2025 23:03:37 +0530
Subject: [PATCH 3/4] remove unnecessary copy

---
 comms/torchcomms/gloo/TorchCommGloo.cpp | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/comms/torchcomms/gloo/TorchCommGloo.cpp b/comms/torchcomms/gloo/TorchCommGloo.cpp
index ba5d1a40..6df7cae3 100644
--- a/comms/torchcomms/gloo/TorchCommGloo.cpp
+++ b/comms/torchcomms/gloo/TorchCommGloo.cpp
@@ -721,9 +721,11 @@ c10::intrusive_ptr TorchCommGloo::all_gather(
           tensorListCPU[i].copy_(chunk);
         }

-        // Copy results back to original tensors (works for both CPU and CUDA)
+        // Copy results back to original device if needed
         for (size_t i = 0; i < tensorListCPU.size(); ++i) {
-          tensor_list[i].copy_(tensorListCPU[i]);
+          if (tensorListCPU[i].device() != tensor_list[i].device()) {
+            tensor_list[i].copy_(tensorListCPU[i]);
+          }
         }
       },
       async_op);
@@ -1312,9 +1314,11 @@ c10::intrusive_ptr TorchCommGloo::gather(
             outputListCPU[i].copy_(chunk);
           }

-          // Copy results back to original tensors (works for both CPU and CUDA)
+          // Copy results back to original device if needed
           for (size_t i = 0; i < outputListCPU.size(); ++i) {
-            output_tensor_list[i].copy_(outputListCPU[i]);
+            if (outputListCPU[i].device() != output_tensor_list[i].device()) {
+              output_tensor_list[i].copy_(outputListCPU[i]);
+            }
           }
         }
       },

From 2f17d1484e44846e978f24e65d96ebe1266c3fba Mon Sep 17 00:00:00 2001
From: arkadip-maitra
Date: Tue, 18 Nov 2025 14:42:02 +0530
Subject: [PATCH 4/4] lint fix

---
 comms/torchcomms/gloo/TorchCommGloo.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/comms/torchcomms/gloo/TorchCommGloo.cpp b/comms/torchcomms/gloo/TorchCommGloo.cpp
index 6df7cae3..896d829e 100644
--- a/comms/torchcomms/gloo/TorchCommGloo.cpp
+++ b/comms/torchcomms/gloo/TorchCommGloo.cpp
@@ -1024,8 +1024,8 @@ c10::intrusive_ptr TorchCommGloo::all_to_all_v_single(
         outputElements.reserve(output_split_sizes.size());

         // Calculate number of elements in each dim 0 chunk.
-        auto dim0Numel =
-            inputCPU.numel() / std::max(inputCPU.size(0), static_cast<int64_t>(1));
+        auto dim0Numel = inputCPU.numel() /
+            std::max(inputCPU.size(0), static_cast<int64_t>(1));
         for (auto size : input_split_sizes) {
           inputElements.push_back(static_cast<int64_t>(size) * dim0Numel);
         }
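
For context on the approach used throughout this series: TorchCommGloo supports CUDA tensors by staging them on the CPU, running the Gloo collective against the CPU buffer, and copying the result back to the tensor's original device. Below is a minimal, self-contained sketch of that round trip using only public ATen calls; it is not part of the patches, and fakeCpuCollective is a hypothetical stand-in for the real gloo::allreduce call so the example runs without Gloo.

#include <ATen/ATen.h>

// Hypothetical stand-in for a CPU-only collective such as gloo::allreduce;
// it simply doubles the values so the sketch is runnable on its own.
static void fakeCpuCollective(at::Tensor& cpuTensor) {
  cpuTensor.mul_(2);
}

// CPU-staging round trip: for a CPU tensor, to(at::kCPU) aliases the input and
// no extra copy happens; for a CUDA tensor it performs a device-to-host copy
// in and a host-to-device copy out.
void allReduceViaCpuStaging(at::Tensor& tensor) {
  at::Tensor tensorCPU = tensor.to(at::kCPU);
  fakeCpuCollective(tensorCPU);
  if (tensorCPU.device() != tensor.device()) {
    // Only copy back when the input actually lived on another device,
    // mirroring the conditional copies the later patches restore.
    tensor.copy_(tensorCPU);
  }
}

int main() {
  auto t = at::ones({4});  // a CPU tensor; a CUDA tensor would take the copy-back path
  allReduceViaCpuStaging(t);
  return 0;
}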