From e13ee59ced9cec152d90d3666094091dec20f3e4 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 3 Nov 2025 14:00:58 +0800 Subject: [PATCH 01/10] Happy Init --- CMakeLists.txt | 5 + comms/torchcomms/device/XpuApi.cpp | 361 ++++ comms/torchcomms/device/XpuApi.hpp | 212 +++ comms/torchcomms/xccl/CMakeLists.txt | 45 + comms/torchcomms/xccl/TorchCommXCCL.cpp | 1454 +++++++++++++++++ comms/torchcomms/xccl/TorchCommXCCL.hpp | 374 +++++ .../xccl/TorchCommXCCLBootstrap.cpp | 307 ++++ .../xccl/TorchCommXCCLBootstrap.hpp | 85 + comms/torchcomms/xccl/TorchCommXCCLPy.cpp | 18 + comms/torchcomms/xccl/TorchCommXCCLUtils.cpp | 458 ++++++ comms/torchcomms/xccl/TorchWorkXCCL.cpp | 142 ++ comms/torchcomms/xccl/TorchWorkXCCL.hpp | 108 ++ comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp | 100 ++ comms/torchcomms/xccl/XcclApi.cpp | 197 +++ comms/torchcomms/xccl/XcclApi.hpp | 289 ++++ setup.py | 7 + 16 files changed, 4162 insertions(+) create mode 100644 comms/torchcomms/device/XpuApi.cpp create mode 100644 comms/torchcomms/device/XpuApi.hpp create mode 100644 comms/torchcomms/xccl/CMakeLists.txt create mode 100644 comms/torchcomms/xccl/TorchCommXCCL.cpp create mode 100644 comms/torchcomms/xccl/TorchCommXCCL.hpp create mode 100644 comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp create mode 100644 comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp create mode 100644 comms/torchcomms/xccl/TorchCommXCCLPy.cpp create mode 100644 comms/torchcomms/xccl/TorchCommXCCLUtils.cpp create mode 100644 comms/torchcomms/xccl/TorchWorkXCCL.cpp create mode 100644 comms/torchcomms/xccl/TorchWorkXCCL.hpp create mode 100644 comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp create mode 100644 comms/torchcomms/xccl/XcclApi.cpp create mode 100644 comms/torchcomms/xccl/XcclApi.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 458663d7..1f12faa2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,10 +10,12 @@ option(USE_NCCL "Whether to build NCCL or not" ON) option(USE_NCCLX "Whether to build NCCLX or not" ON) option(USE_GLOO "Whether to build Gloo or not" ON) option(USE_RCCL "Whether to build RCCL or not" OFF) +option(USE_XCCL "Whether to build XCCL or not" OFF) message(STATUS " USE_NCCL : ${USE_NCCL}") message(STATUS " USE_NCCLX : ${USE_NCCLX}") message(STATUS " USE_GLOO : ${USE_GLOO}") message(STATUS " USE_RCCL : ${USE_RCCL}") +message(STATUS " USE_XCCL : ${USE_XCCL}") # Find Python and PyTorch find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED) find_package(Torch REQUIRED) @@ -108,6 +110,9 @@ endif() if (USE_RCCL) include(comms/torchcomms/rccl/CMakeLists.txt) endif() +if (USE_XCCL) + include(comms/torchcomms/xccl/CMakeLists.txt) +endif() # Install targets to Python package structure install(TARGETS torchcomms diff --git a/comms/torchcomms/device/XpuApi.cpp b/comms/torchcomms/device/XpuApi.cpp new file mode 100644 index 00000000..d6d594cd --- /dev/null +++ b/comms/torchcomms/device/XpuApi.cpp @@ -0,0 +1,361 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
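+//
+// XpuApi.cpp implements a thin, CUDA-style shim over the PyTorch XPU / SYCL
+// runtime so the XCCL backend can mirror the structure of the NCCL/RCCL
+// backends. A minimal caller sketch (hypothetical, not part of this patch;
+// it only uses methods and the XPU_CHECK macro defined in this file/header):
+//
+//   DefaultXpuApi xpu;
+//   int count = 0;
+//   XPU_CHECK(&xpu, xpu.getDeviceCount(&count), "device count query failed");
+//   xpuStream_t stream = xpu.getCurrentXPUStream(/*device_index=*/0);
+//   XPU_CHECK(&xpu, xpu.streamSynchronize(stream), "stream sync failed");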
+ +#include "comms/torchcomms/device/XpuApi.hpp" +#include +#include +#include +#include +#include + +namespace torch { +namespace comms { + +// ============================================================================ +// Device Management Implementation +// ============================================================================ + +xpu_result_t DefaultXpuApi::setDevice(int device) { + try { + ::c10::xpu::set_device(device); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::getDeviceProperties(xpuDeviceProp* prop, int device) { + if (!prop) { + return XPU_ERROR_INVALID_VALUE; + } + + try { + ::c10::xpu::DeviceProp device_prop; + ::c10::xpu::get_device_properties(&device_prop, device); + + // Map c10::xpu::DeviceProp to xpuDeviceProp + strncpy(prop->name, device_prop.name.c_str(), 255); + prop->name[255] = '\0'; + prop->totalGlobalMem = device_prop.global_mem_size; + + // Extract major/minor version from device_id + prop->major = (device_prop.device_id >> 16) & 0xFFFF; + prop->minor = device_prop.device_id & 0xFFFF; + + // Get SYCL device for additional properties + sycl::device& sycl_device = ::c10::xpu::get_raw_device(device); + auto max_work_group_size = sycl_device.get_info(); + auto max_work_item_sizes = sycl_device.get_info>(); + auto max_compute_units = sycl_device.get_info(); + + prop->multiProcessorCount = max_compute_units; + prop->maxThreadsPerBlock = max_work_group_size; + prop->maxThreadsDim[0] = max_work_item_sizes[0]; + prop->maxThreadsDim[1] = max_work_item_sizes[1]; + prop->maxThreadsDim[2] = max_work_item_sizes[2]; + + // Max grid size - not directly available in SYCL, use reasonable defaults + prop->maxGridSize[0] = 2147483647; + prop->maxGridSize[1] = 65535; + prop->maxGridSize[2] = 65535; + + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::memGetInfo(size_t* free, size_t* total) { + if (!free || !total) { + return XPU_ERROR_INVALID_VALUE; + } + + try { + int device = ::c10::xpu::current_device(); + sycl::device& sycl_device = ::c10::xpu::get_raw_device(device); + + *total = sycl_device.get_info(); + *free = *total; // SYCL doesn't provide free memory query + + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::getDeviceCount(int* count) { + if (!count) { + return XPU_ERROR_INVALID_VALUE; + } + + try { + *count = ::c10::xpu::device_count(); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +// ============================================================================ +// Stream Management Implementation +// ============================================================================ + +xpu_result_t DefaultXpuApi::streamCreateWithPriority( + xpuStream_t& stream, + unsigned int flags, + int priority) { + try { + // Map priority: priority < 0 = high, priority >= 0 = normal + bool isHighPriority = (priority < 0); + stream = ::c10::xpu::getStreamFromPool(isHighPriority); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::streamDestroy(const xpuStream_t& stream) { + // Stream is managed by PyTorch, nothing to do + return XPU_SUCCESS; +} + +xpu_result_t DefaultXpuApi::streamWaitEvent( + const xpuStream_t& stream, + xpuEvent_t& event, + unsigned int flags) { + try { + event.block(stream); + return 
XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_HANDLE; + } +} + +xpuStream_t DefaultXpuApi::getCurrentXPUStream(int device_index) { + return ::c10::xpu::getCurrentXPUStream(device_index); +} + +xpu_result_t DefaultXpuApi::streamSynchronize(const xpuStream_t& stream) { + try { + stream.queue().wait_and_throw(); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_HANDLE; + } +} + +xpu_result_t DefaultXpuApi::streamIsCapturing( + const xpuStream_t& stream, + xpuStreamCaptureStatus* pCaptureStatus) { + if (!pCaptureStatus) { + return XPU_ERROR_INVALID_VALUE; + } + + // XPU/SYCL doesn't support stream capture + *pCaptureStatus = xpuStreamCaptureStatusNone; + return XPU_SUCCESS; +} + +xpu_result_t DefaultXpuApi::streamGetCaptureInfo( + const xpuStream_t& stream, + xpuStreamCaptureStatus* pCaptureStatus, + unsigned long long* pId) { + if (!pCaptureStatus) { + return XPU_ERROR_INVALID_VALUE; + } + + *pCaptureStatus = xpuStreamCaptureStatusNone; + if (pId) { + *pId = 0; + } + return XPU_SUCCESS; +} + +// ============================================================================ +// Memory Management Implementation +// ============================================================================ + +xpu_result_t DefaultXpuApi::malloc(void** devPtr, size_t size) { + if (!devPtr) { + return XPU_ERROR_INVALID_VALUE; + } + + if (size == 0) { + *devPtr = nullptr; + return XPU_SUCCESS; + } + + try { + // Use SYCL's malloc_device + sycl::context& ctx = ::c10::xpu::get_device_context(); + int device = ::c10::xpu::current_device(); + sycl::device& dev = ::c10::xpu::get_raw_device(device); + + *devPtr = sycl::malloc_device(size, dev, ctx); + + if (!*devPtr) { + return XPU_ERROR_OUT_OF_MEMORY; + } + + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_OUT_OF_MEMORY; + } +} + +xpu_result_t DefaultXpuApi::free(void* devPtr) { + if (!devPtr) { + return XPU_SUCCESS; + } + + try { + sycl::context& ctx = ::c10::xpu::get_device_context(); + sycl::free(devPtr, ctx); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::memcpyAsync( + void* dst, + const void* src, + size_t count, + const xpuStream_t& stream) { + if (!dst || !src) { + return XPU_ERROR_INVALID_VALUE; + } + + if (count == 0) { + return XPU_SUCCESS; + } + + try { + stream.queue().memcpy(dst, src, count); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +// ============================================================================ +// Event Management Implementation +// ============================================================================ + +xpu_result_t DefaultXpuApi::eventCreate(xpuEvent_t& event) { + try { + event = ::at::xpu::XPUEvent(false); // No timing + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::eventCreateWithFlags( + xpuEvent_t& event, + unsigned int flags) { + try { + bool enable_timing = (flags & 0x1) != 0; + event = ::at::xpu::XPUEvent(enable_timing); + return XPU_SUCCESS; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_VALUE; + } +} + +xpu_result_t DefaultXpuApi::eventDestroy(const xpuEvent_t& event) { + // Event is RAII, nothing to do + return XPU_SUCCESS; +} + +xpu_result_t DefaultXpuApi::eventRecord(xpuEvent_t& event, const xpuStream_t& stream) { + try { + event.record(stream); + return XPU_SUCCESS; + } catch 
(const std::exception& e) { + return XPU_ERROR_INVALID_HANDLE; + } +} + +xpu_result_t DefaultXpuApi::eventQuery(const xpuEvent_t& event) { + try { + bool is_complete = event.query(); + return is_complete ? XPU_SUCCESS : XPU_ERROR_NOT_READY; + } catch (const std::exception& e) { + return XPU_ERROR_INVALID_HANDLE; + } +} + +// ============================================================================ +// Graph Operations (Unsupported) +// ============================================================================ + +xpu_result_t DefaultXpuApi::userObjectCreate( + xpuUserObject_t* object_out, + void* ptr, + xpuHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags) { + // XPU/SYCL doesn't support user objects + return XPU_ERROR_UNSUPPORTED; +} + +xpu_result_t DefaultXpuApi::graphRetainUserObject( + xpuGraph_t graph, + xpuUserObject_t object, + unsigned int count, + unsigned int flags) { + // XPU/SYCL doesn't support graphs + return XPU_ERROR_UNSUPPORTED; +} + +xpu_result_t DefaultXpuApi::streamGetCaptureInfo_v2( + const xpuStream_t& stream, + xpuStreamCaptureStatus* captureStatus_out, + unsigned long long* id_out, + xpuGraph_t* graph_out, + const xpuGraphNode_t** dependencies_out, + size_t* numDependencies_out) { + if (captureStatus_out) { + *captureStatus_out = xpuStreamCaptureStatusNone; + } + if (id_out) { + *id_out = 0; + } + if (graph_out) { + *graph_out = nullptr; + } + if (dependencies_out) { + *dependencies_out = nullptr; + } + if (numDependencies_out) { + *numDependencies_out = 0; + } + + return XPU_SUCCESS; +} + +// ============================================================================ +// Error Handling +// ============================================================================ + +const char* DefaultXpuApi::getErrorString(xpu_result_t error) { + switch (error) { + case XPU_SUCCESS: + return "success"; + case XPU_ERROR_INVALID_VALUE: + return "invalid value"; + case XPU_ERROR_NOT_READY: + return "not ready"; + case XPU_ERROR_INVALID_HANDLE: + return "invalid handle"; + case XPU_ERROR_OUT_OF_MEMORY: + return "out of memory"; + case XPU_ERROR_UNSUPPORTED: + return "unsupported feature"; + default: + return "unknown error"; + } +} + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/device/XpuApi.hpp b/comms/torchcomms/device/XpuApi.hpp new file mode 100644 index 00000000..f6dee9c7 --- /dev/null +++ b/comms/torchcomms/device/XpuApi.hpp @@ -0,0 +1,212 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
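+//
+// Status codes in this header follow the CUDA convention: every call returns
+// an xpu_result_t and the caller decides whether to throw (see XPU_CHECK
+// below). Illustrative polling loop, assuming `api` is an XpuApi* and
+// `stream` is an already-created stream (sketch only):
+//
+//   xpuEvent_t ev;
+//   api->eventCreate(ev);
+//   api->eventRecord(ev, stream);
+//   while (api->eventQuery(ev) == XPU_ERROR_NOT_READY) {
+//     std::this_thread::yield();  // not complete yet; keep polling
+//   }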
+ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace comms { + +// Use PyTorch XPU types directly as value types +using xpuStream_t = ::c10::xpu::XPUStream; +using xpuEvent_t = ::at::xpu::XPUEvent; + +// Device properties structure - mapped from SYCL +struct xpuDeviceProp { + char name[256]; + size_t totalGlobalMem; + int major; + int minor; + int multiProcessorCount; + int maxThreadsPerBlock; + int maxThreadsDim[3]; + int maxGridSize[3]; +}; + +// Graph-related types (placeholder - unsupported in XPU) +using xpuGraph_t = void*; +using xpuGraphNode_t = void*; +using xpuUserObject_t = void*; +using xpuHostFn_t = void(*)(void*); + +// Stream capture status (not supported in XPU) +enum xpuStreamCaptureStatus { + xpuStreamCaptureStatusNone = 0, +}; + +// Error code type +using xpu_result_t = int32_t; +constexpr xpu_result_t XPU_SUCCESS = 0; +constexpr xpu_result_t XPU_ERROR_INVALID_VALUE = 1; +constexpr xpu_result_t XPU_ERROR_NOT_READY = 2; +constexpr xpu_result_t XPU_ERROR_INVALID_HANDLE = 3; +constexpr xpu_result_t XPU_ERROR_OUT_OF_MEMORY = 4; +constexpr xpu_result_t XPU_ERROR_UNSUPPORTED = 5; + +#define XPU_CHECK(xpu_api, call, err_str) \ + do { \ + xpu_result_t status = call; \ + if (status != XPU_SUCCESS) { \ + std::stringstream ss; \ + ss << err_str << ": " << xpu_api->getErrorString(status) << " at " \ + << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(ss.str()); \ + } \ + } while (0) + +/** + * Abstract interface for XPU API operations. + * This allows for dependency injection and testing by providing + * a way to override XPU API calls. + */ +class XpuApi { + public: + virtual ~XpuApi() = default; + + // Device management + virtual xpu_result_t setDevice(int device) = 0; + virtual xpu_result_t getDeviceProperties(xpuDeviceProp* prop, int device) = 0; + virtual xpu_result_t memGetInfo(size_t* free, size_t* total) = 0; + virtual xpu_result_t getDeviceCount(int* count) = 0; + + // Stream management + virtual xpu_result_t streamCreateWithPriority( + xpuStream_t& stream, + unsigned int flags, + int priority) = 0; + virtual xpu_result_t streamDestroy(const xpuStream_t& stream) = 0; + virtual xpu_result_t streamWaitEvent( + const xpuStream_t& stream, + xpuEvent_t& event, + unsigned int flags) = 0; + virtual xpuStream_t getCurrentXPUStream(int device_index) = 0; + virtual xpu_result_t streamSynchronize(const xpuStream_t& stream) = 0; + virtual xpu_result_t streamIsCapturing( + const xpuStream_t& stream, + xpuStreamCaptureStatus* pCaptureStatus) = 0; + virtual xpu_result_t streamGetCaptureInfo( + const xpuStream_t& stream, + xpuStreamCaptureStatus* pCaptureStatus, + unsigned long long* pId) = 0; + + // Memory management + virtual xpu_result_t malloc(void** devPtr, size_t size) = 0; + virtual xpu_result_t free(void* devPtr) = 0; + virtual xpu_result_t memcpyAsync( + void* dst, + const void* src, + size_t count, + const xpuStream_t& stream) = 0; + + // Event management + virtual xpu_result_t eventCreate(xpuEvent_t& event) = 0; + virtual xpu_result_t eventCreateWithFlags( + xpuEvent_t& event, + unsigned int flags) = 0; + virtual xpu_result_t eventDestroy(const xpuEvent_t& event) = 0; + virtual xpu_result_t eventRecord(xpuEvent_t& event, const xpuStream_t& stream) = 0; + virtual xpu_result_t eventQuery(const xpuEvent_t& event) = 0; + + // Graph operations (unsupported, kept for API compatibility) + virtual xpu_result_t userObjectCreate( + xpuUserObject_t* object_out, + void* ptr, + xpuHostFn_t destroy, + unsigned int initialRefcount, + 
unsigned int flags) = 0; + virtual xpu_result_t graphRetainUserObject( + xpuGraph_t graph, + xpuUserObject_t object, + unsigned int count, + unsigned int flags) = 0; + virtual xpu_result_t streamGetCaptureInfo_v2( + const xpuStream_t& stream, + xpuStreamCaptureStatus* captureStatus_out, + unsigned long long* id_out, + xpuGraph_t* graph_out, + const xpuGraphNode_t** dependencies_out, + size_t* numDependencies_out) = 0; + + // Error handling + virtual const char* getErrorString(xpu_result_t error) = 0; +}; + +/** + * Default implementation that uses SYCL and PyTorch XPU APIs. + */ +class DefaultXpuApi : public XpuApi { + public: + ~DefaultXpuApi() override = default; + + // Device management + xpu_result_t setDevice(int device) override; + xpu_result_t getDeviceProperties(xpuDeviceProp* prop, int device) override; + xpu_result_t memGetInfo(size_t* free, size_t* total) override; + xpu_result_t getDeviceCount(int* count) override; + + // Stream management + xpu_result_t streamCreateWithPriority( + xpuStream_t& stream, + unsigned int flags, + int priority) override; + xpu_result_t streamDestroy(const xpuStream_t& stream) override; + xpu_result_t streamWaitEvent( + const xpuStream_t& stream, + xpuEvent_t& event, + unsigned int flags) override; + xpuStream_t getCurrentXPUStream(int device_index) override; + xpu_result_t streamSynchronize(const xpuStream_t& stream) override; + xpu_result_t streamIsCapturing( + const xpuStream_t& stream, + xpuStreamCaptureStatus* pCaptureStatus) override; + xpu_result_t streamGetCaptureInfo( + const xpuStream_t& stream, + xpuStreamCaptureStatus* pCaptureStatus, + unsigned long long* pId) override; + + // Memory management + xpu_result_t malloc(void** devPtr, size_t size) override; + xpu_result_t free(void* devPtr) override; + xpu_result_t memcpyAsync( + void* dst, + const void* src, + size_t count, + const xpuStream_t& stream) override; + + // Event management + xpu_result_t eventCreate(xpuEvent_t& event) override; + xpu_result_t eventCreateWithFlags(xpuEvent_t& event, unsigned int flags) override; + xpu_result_t eventDestroy(const xpuEvent_t& event) override; + xpu_result_t eventRecord(xpuEvent_t& event, const xpuStream_t& stream) override; + xpu_result_t eventQuery(const xpuEvent_t& event) override; + + // Graph operations (unsupported) + xpu_result_t userObjectCreate( + xpuUserObject_t* object_out, + void* ptr, + xpuHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags) override; + xpu_result_t graphRetainUserObject( + xpuGraph_t graph, + xpuUserObject_t object, + unsigned int count, + unsigned int flags) override; + xpu_result_t streamGetCaptureInfo_v2( + const xpuStream_t& stream, + xpuStreamCaptureStatus* captureStatus_out, + unsigned long long* id_out, + xpuGraph_t* graph_out, + const xpuGraphNode_t** dependencies_out, + size_t* numDependencies_out) override; + + // Error handling + const char* getErrorString(xpu_result_t error) override; +}; + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/CMakeLists.txt b/comms/torchcomms/xccl/CMakeLists.txt new file mode 100644 index 00000000..dc69015f --- /dev/null +++ b/comms/torchcomms/xccl/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
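+#
+# Builds the torchcomms._comms_xccl Python extension against oneCCL. The
+# oneCCL headers and libccl.so.2 are located through the CCL_ROOT environment
+# variable, so a typical configure step looks like (illustrative invocation,
+# the install path is an example only):
+#
+#   CCL_ROOT=/opt/intel/oneapi/ccl/latest cmake -DUSE_XCCL=ON ..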
+# Extension: torchcomms._comms_xccl
+file(GLOB TORCHCOMMS_XCCL_SOURCES "comms/torchcomms/xccl/*.cpp")
+file(GLOB TORCHCOMMS_XPU_API_SOURCE "comms/torchcomms/device/XpuApi.cpp")
+# find_package(XPU)
+
+set(XCCL_INCLUDE "$ENV{CCL_ROOT}/include")
+set(XCCL_SHARED_LIB "$ENV{CCL_ROOT}/lib/libccl.so.2")
+
+include(FindPackageHandleStandardArgs)
+
+
+add_library(torchcomms_comms_xccl MODULE
+  ${TORCHCOMMS_XCCL_SOURCES}
+  ${TORCHCOMMS_XPU_API_SOURCE}
+)
+set_target_properties(torchcomms_comms_xccl PROPERTIES
+  PREFIX ""
+  OUTPUT_NAME "_comms_xccl"
+  SUFFIX ".${Python3_SOABI}.so"
+  LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/comms/torchcomms"
+  BUILD_RPATH "$ORIGIN"
+  INSTALL_RPATH "$ORIGIN"
+)
+target_include_directories(torchcomms_comms_xccl PRIVATE
+  ${ROOT}
+  ${XCCL_INCLUDE}
+  ${CONDA_INCLUDE}
+  ${Python3_INCLUDE_DIRS}
+)
+target_compile_features(torchcomms_comms_xccl PRIVATE cxx_std_20)
+target_link_directories(torchcomms_comms_xccl PRIVATE ${CONDA_LIB})
+target_link_libraries(torchcomms_comms_xccl PRIVATE
+  ${TORCH_LIBRARIES}
+  ${TORCH_PYTHON_LIB}
+  torchcomms
+)
+
+target_link_libraries(torchcomms_comms_xccl PRIVATE
+  ${XCCL_SHARED_LIB}
+)
+
+install(TARGETS torchcomms_comms_xccl
+  LIBRARY DESTINATION .
+)
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp
new file mode 100644
index 00000000..90d375e4
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp
@@ -0,0 +1,1454 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#include "comms/torchcomms/xccl/TorchCommXCCL.hpp"
+
+#include
+#include
+#include
+#include
+#include "comms/torchcomms/TorchCommFactory.hpp"
+#include "comms/torchcomms/TorchCommLogging.hpp"
+#include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp"
+// #include "xccl.h" // @manual
+
+namespace torch {
+namespace comms {
+
+onecclResult_t XCCLException::getResult() const {
+  return result_;
+}
+
+TorchCommXCCL::TorchCommXCCL()
+    : xccl_comm_{nullptr},
+      device_(at::kXPU),
+      init_state_(InitializationState::UNINITIALIZED),
+      shutdown_(false) {}
+
+TorchCommXCCL::TorchCommXCCL(const onecclComm_t xccl_comm)
+    : xccl_comm_(xccl_comm),
+      device_(at::kXPU),
+      init_state_(InitializationState::UNINITIALIZED),
+      shutdown_(false) {}
+
+TorchCommXCCL::~TorchCommXCCL() {
+  if (init_state_ == InitializationState::INITIALIZED) {
+    TC_LOG(ERROR) << "TorchCommXCCL was not finalized before destruction";
+  }
+
+  // We need to detach the memory hook in case finalize is not called,
+  // so that we don't encounter memory corruption.
+ detachMemoryHook(); +} + +void TorchCommXCCL::init( + at::Device device, + const std::string& name, + const CommOptions& options) { + // Initialize private members + device_ = device; + name_ = name; + options_ = options; + + // Only initialize once + if (init_state_ == InitializationState::INITIALIZED) { + throw std::runtime_error("TorchCommXCCL already initialized"); + } else if (init_state_ == InitializationState::FINALIZED) { + throw std::runtime_error("TorchCommXCCL already finalized"); + } + init_state_ = InitializationState::INITIALIZED; + + // Initialize default XCCL API implementation if not already set + if (!xccl_api_) { + xccl_api_ = std::make_unique(); + } + + // Initialize default XPU API implementation if not already set + if (!xpu_api_) { + xpu_api_ = std::make_unique(); + } + + if (device_.index() == -1 || xccl_comm_ == nullptr) { + TC_LOG(INFO) << "[TC] Creating bootstrap..."; + auto bootstrap = new TorchCommXCCLBootstrap( + options_.store, device_, xccl_api_, xpu_api_, options_.timeout); + device_ = bootstrap->getDevice(); + + if (xccl_comm_ == nullptr) { + TC_LOG(INFO) << "[TC] Creating XCCL communicator..."; + xccl_comm_ = bootstrap->createXcclComm(name, options); + TC_LOG(INFO) << "[TC] XCCL communicator created"; + } + + delete bootstrap; + TC_LOG(INFO) << "[TC] Bootstrap deleted"; + } + + // Set XPU device and verify it's accessible + TC_LOG(INFO) << "[TC] Setting XPU device " << device_.index(); + XPU_CHECK( + xpu_api_, + xpu_api_->setDevice(device_.index()), + "Failed to set XPU device to " + std::to_string(device_.index())); + + // Verify device properties and memory availability + TC_LOG(INFO) << "[TC] Getting device properties..."; + xpuDeviceProp device_prop = {}; + XPU_CHECK( + xpu_api_, + xpu_api_->getDeviceProperties(&device_prop, device_.index()), + "Failed to get device properties for device " + + std::to_string(device_.index())); + + // Check available memory + TC_LOG(INFO) << "[TC] Getting memory info..."; + size_t free_memory, total_memory; + XPU_CHECK( + xpu_api_, + xpu_api_->memGetInfo(&free_memory, &total_memory), + "Failed to get memory info for device " + + std::to_string(device_.index())); + + // Read hints and store them + for (auto const& [key, val] : options_.hints) { + if (key.starts_with("torchcomm::xccl::")) { + if (key == "torchcomm::xccl::high_priority_stream") { + high_priority_stream_ = string_to_bool(val); + } else { + throw std::runtime_error("Unrecognized hint " + key); + } + } else { + // Ignore keys that do not start with "torchcomm::xccl::" + } + } + + // Create internal stream + int stream_priority = 0; + + // Check for high priority stream hint + if (high_priority_stream_) { + stream_priority = -1; + } + + // Initialize internal stream + // Get a temporary stream first (to have a valid object), then replace it + TC_LOG(INFO) << "[TC] Creating internal XPU stream with priority " << stream_priority; + xpuStream_t temp_stream = xpu_api_->getCurrentXPUStream(device_.index()); + XPU_CHECK( + xpu_api_, + xpu_api_->streamCreateWithPriority( + temp_stream, /*flags=*/0, stream_priority), + "Failed to create internal XPU stream on device " + + std::to_string(device_.index())); + internal_stream_ = std::move(temp_stream); + TC_LOG(INFO) << "[TC] Internal stream created"; + + // Create dependency event for stream synchronization + TC_LOG(INFO) << "[TC] Creating dependency event"; + xpuEvent_t temp_event(/*enable_timing=*/false); + XPU_CHECK( + xpu_api_, + xpu_api_->eventCreateWithFlags( + temp_event, /*flags=*/0), + "Failed to create 
dependency event on device " + + std::to_string(device_.index())); + TC_LOG(INFO) << "[TC] Dependency event created"; + dependency_event_ = std::move(temp_event); + + // Allocate XPU buffer for barrier operations + TC_LOG(INFO) << "[TC] Allocating barrier buffer"; + XPU_CHECK( + xpu_api_, + xpu_api_->malloc(&barrier_buffer_, sizeof(float)), + "Failed to allocate barrier buffer"); + TC_LOG(INFO) << "[TC] Barrier buffer allocated"; + + if (options_.hints.contains("torchcomm::xccl::max_event_pool_size")) { + max_event_pool_size_ = + std::stoull(options_.hints.at("torchcomm::xccl::max_event_pool_size")); + } else { + max_event_pool_size_ = kMaxEventPoolSize; + } + + // Give up our internal reference to the store object here. The caller + // would still need to keep a reference to the store object till the init + // call returns, at which point the XCCL communicator would already be + // created. + if (options_.store) { + options_.store.reset(); + } + + TC_LOG(INFO) << "[TC] Getting XCCL rank"; + onecclResult_t xcclErr; + xcclErr = xccl_api_->commUserRank(xccl_comm_, &rank_); + if (xcclErr != onecclSuccess) { + throw std::runtime_error("XCCL User Rank failed"); + } + TC_LOG(INFO) << "[TC] Rank: " << rank_; + + tryTorchCommLoggingInit("torchcomm"); + + TC_LOG(INFO) << "[TC] Getting XCCL comm size"; + xcclErr = xccl_api_->commCount(xccl_comm_, &comm_size_); + if (xcclErr != onecclSuccess) { + throw std::runtime_error("XCCL Count failed"); + } + TC_LOG(INFO) << "[TC] Comm size: " << comm_size_; + + TC_LOG(INFO) << "[TC] Creating tracing object"; + tracing_ = std::make_shared(name, comm_size_, rank_); + tracing_->recordEvent("init"); + + // Start timeout watchdog thread + TC_LOG(INFO) << "[TC] Starting timeout watchdog thread"; + timeout_thread_ = std::thread(&TorchCommXCCL::timeoutWatchdog, this); + + // Register comm with CachingAllocator + TC_LOG(INFO) << "[TC] Attaching memory hook"; + attachMemoryHook(); + TC_LOG(INFO) << "[TC] TorchCommXCCL initialization completed"; +} + +void TorchCommXCCL::finalize() { + if (init_state_ == InitializationState::UNINITIALIZED) { + throw std::runtime_error("TorchCommXCCL not initialized"); + } else if (init_state_ == InitializationState::FINALIZED) { + throw std::runtime_error("TorchCommXCCL already finalized"); + } + init_state_ = InitializationState::FINALIZED; + + // Signal shutdown to timeout watchdog + shutdown_ = true; + + // Wake up the timeout watchdog thread + { + std::lock_guard lock(timeout_mutex_); + timeout_cv_.notify_all(); + } + + // Wait for timeout thread to finish + if (timeout_thread_.joinable()) { + timeout_thread_.join(); + } + + // Wait for all pending work objects to complete and get final status + auto work_status = workq_.finalize(); + + if (work_status == TorchWorkXCCL::WorkStatus::NOT_STARTED || + work_status == TorchWorkXCCL::WorkStatus::INPROGRESS) { + throw std::runtime_error( + "WorkQ finalize returned in progress or not started state"); + } + + // Update comm_state_ based on the work status + if (work_status == TorchWorkXCCL::WorkStatus::TIMEDOUT) { + comm_state_ = CommState::TIMEOUT; + abortXcclComm(); + throw std::runtime_error("Work timed out during finalize"); + } else if (work_status == TorchWorkXCCL::WorkStatus::ERROR) { + comm_state_ = CommState::ERROR; + onecclResult_t asyncErr; + xccl_api_->commGetAsyncError(xccl_comm_, &asyncErr); + XCCLException xcclException(*xccl_api_, "XCCL Async Error", asyncErr); + abortXcclComm(); + throw xcclException; + } + + // Clean up event pool + { + std::lock_guard 
lock(event_pool_mutex_); + while (!event_pool_.empty()) { + xpuEvent_t event = std::move(event_pool_.front()); + event_pool_.pop(); + XPU_CHECK( + xpu_api_, xpu_api_->eventDestroy(event), "Failed to destroy event"); + } + } + + // Free barrier buffer. TODO: handle errors on xpu free and stream destroy + if (barrier_buffer_) { + XPU_CHECK( + xpu_api_, + xpu_api_->free(barrier_buffer_), + "Failed to free barrier buffer"); + barrier_buffer_ = nullptr; + } + + // Destroy dependency event + if (dependency_event_.has_value()) { + XPU_CHECK( + xpu_api_, + xpu_api_->eventDestroy(dependency_event_.value()), + "Failed to destroy dependency event"); + dependency_event_.reset(); + } + + // Destroy internal stream + if (internal_stream_.has_value()) { + XPU_CHECK( + xpu_api_, + xpu_api_->streamDestroy(internal_stream_.value()), + "Failed to destroy internal stream"); + internal_stream_.reset(); + } + + // Destroy XCCL communicator + // TODO: should probably not call this after calling abort. + if (xccl_comm_) { + detachMemoryHook(); + // Deregister comm from the CachingAllocator + xccl_api_->commDestroy(xccl_comm_); + xccl_comm_ = nullptr; + } +} + +void TorchCommXCCL::abortXcclComm() { + detachMemoryHook(); + if (xccl_comm_) { + xccl_api_->commAbort(xccl_comm_); + xccl_comm_ = nullptr; + } + if (options_.abort_process_on_timeout_or_error) { + TC_LOG(ERROR) << "Aborting process due to timeout"; + abort(); + } +} + +int TorchCommXCCL::getRank() const { + checkInitialized(); + + int rank; + onecclResult_t xcclErr = xccl_api_->commUserRank(xccl_comm_, &rank); + if (xcclErr != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL User Rank failed", xcclErr); + } + return rank; +} + +int TorchCommXCCL::getSize() const { + checkInitialized(); + + int comm_size; + onecclResult_t xcclErr = xccl_api_->commCount(xccl_comm_, &comm_size); + if (xcclErr != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL Count failed", xcclErr); + } + return comm_size; +} + +std::string_view TorchCommXCCL::getBackendName() const { + return kBackendName; +} + +std::string_view TorchCommXCCL::getCommName() const { + return name_; +} + +static inline std::chrono::milliseconds getOperationTimeout( + std::chrono::milliseconds timeout, + std::chrono::milliseconds default_timeout) { + // If timeout is kNoTimeout (0ms), use the default timeout from options + if (timeout == kNoTimeout) { + return default_timeout; + } + return timeout; +} + +// Point-to-Point Operations +std::shared_ptr TorchCommXCCL::send( + const at::Tensor& tensor, + int dst, + bool async_op, + const SendOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(tensor); + + tracing_->recordEventWithInputOutput("send", dst, {tensor}, {tensor}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); + + // Record start event before XCCL operation + work->recordStart(); + + onecclResult_t result = xccl_api_->send( + tensor.data_ptr(), + tensor.numel(), + getXcclDataType(tensor), + dst, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL Send failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::recv( + at::Tensor& tensor, + int src, + bool async_op, + const RecvOptions& options) { + checkInitialized(); 
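+  // Same preamble as send(): bail out early if the comm is uninitialized or
+  // a previous operation has already timed out or errored, before touching
+  // the tensor.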
+ checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(tensor); + + tracing_->recordEventWithInputOutput("recv", src, {tensor}, {tensor}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {}); + + // Record start event before XCCL operation + work->recordStart(); + + onecclResult_t result = xccl_api_->recv( + tensor.data_ptr(), + tensor.numel(), + getXcclDataType(tensor), + src, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL Recv failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +// Batch P2P Operations +std::shared_ptr TorchCommXCCL::batch_op_issue( + const std::vector& ops, + bool async_op, + const BatchP2POptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + + if (ops.empty()) { + throw std::runtime_error("Cannot issue empty batch operation"); + } + + // Collect input and output tensors for work tracking + std::vector input_tensors; + std::vector output_tensors; + + for (const auto& op : ops) { + if (op.type == BatchSendRecv::P2POp::OpType::SEND) { + at::Tensor tensor = op.tensor; + ensureTensorContiguous(tensor); + input_tensors.push_back(tensor); + } else if (op.type == BatchSendRecv::P2POp::OpType::RECV) { + at::Tensor tensor = op.tensor; + ensureTensorContiguous(tensor); + output_tensors.push_back(tensor); + } else { + throw std::runtime_error("Unknown op type"); + } + } + + tracing_->recordEventWithInputOutput( + "batch_op_issue", rank_, input_tensors, output_tensors); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, + getOperationTimeout(options.timeout, options_.timeout), + input_tensors); + + // Record start event before XCCL operations + work->recordStart(); + + // Start XCCL group for batched operations + onecclResult_t result = xccl_api_->groupStart(); + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL GroupStart failed", result); + } + + // Issue each operation individually + for (const auto& op : ops) { + if (op.type == BatchSendRecv::P2POp::OpType::SEND) { + result = xccl_api_->send( + op.tensor.data_ptr(), + op.tensor.numel(), + getXcclDataType(op.tensor), + op.peer, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + xccl_api_->groupEnd(); // Clean up group on error + throw XCCLException( + *xccl_api_, "XCCL Send failed in batch operation", result); + } + } else if (op.type == BatchSendRecv::P2POp::OpType::RECV) { + result = xccl_api_->recv( + op.tensor.data_ptr(), + op.tensor.numel(), + getXcclDataType(op.tensor), + op.peer, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + xccl_api_->groupEnd(); // Clean up group on error + throw XCCLException( + *xccl_api_, "XCCL Recv failed in batch operation", result); + } + } + } + + // End XCCL group + result = xccl_api_->groupEnd(); + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL GroupEnd failed", result); + } + + // Record end event after XCCL operations + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +// Collective Operations +std::shared_ptr TorchCommXCCL::broadcast( + at::Tensor& tensor, + int root, + bool async_op, + const BroadcastOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + 
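+  // XCCL bcast is in place: `tensor` is the send buffer on `root` and the
+  // receive buffer on every other rank, so it must be contiguous everywhere.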
ensureTensorContiguous(tensor); + + tracing_->recordEventWithInputOutput("broadcast", root, {tensor}, {tensor}); + + xpuStream_t stream = getOperationStream(async_op); + + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); + + // Record start event before XCCL operation + work->recordStart(); + + onecclResult_t result = xccl_api_->bcast( + tensor.data_ptr(), + tensor.numel(), + getXcclDataType(tensor), + root, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL Broadcast failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::all_reduce( + at::Tensor& tensor, + ReduceOp op, + bool async_op, + const AllReduceOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(tensor); + + tracing_->recordEventWithInputOutput("all_reduce", rank_, {tensor}, {tensor}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); + + // Record start event before XCCL operation + work->recordStart(); + + const auto dataType = getXcclDataType(tensor); + onecclResult_t result = xccl_api_->allReduce( + tensor.data_ptr(), + tensor.data_ptr(), // In-place operation + tensor.numel(), + dataType, + getXcclReduceOp(op, xccl_comm_, dataType), + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL AllReduce failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::reduce( + const at::Tensor& tensor, + int root, + ReduceOp op, + bool async_op, + const ReduceOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(tensor); + + tracing_->recordEventWithInputOutput("reduce", root, {tensor}, {tensor}); + + xpuStream_t stream = getOperationStream(async_op); + std::vector output_tensors; + if (rank_ == root) { + output_tensors.push_back(tensor); + } + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); + + // Record start event before XCCL operation + work->recordStart(); + + const auto dataType = getXcclDataType(tensor); + onecclResult_t result = xccl_api_->reduce( + tensor.data_ptr(), + rank_ == root ? 
tensor.data_ptr() : nullptr, + tensor.numel(), + dataType, + getXcclReduceOp(op, xccl_comm_, dataType), + root, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL Reduce failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::all_gather( + const std::vector& tensor_list, + const at::Tensor& tensor, + bool async_op, + const AllGatherOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + if (tensor_list.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "tensor_list size must equal comm_size for all_gather"); + } + + // Ensure input tensor is contiguous + ensureTensorContiguous(tensor); + + // Check that all output tensors are contiguous and have correct size + for (const auto& t : tensor_list) { + ensureTensorContiguous(t); + if (t.numel() != tensor.numel()) { + throw std::runtime_error( + "All tensors in tensor_list must have same size as input tensor"); + } + } + + tracing_->recordEventWithInputOutput( + "all_gather", rank_, tensor_list, {tensor}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); + + work->recordStart(); + + // Use multiple broadcast operations for all_gather + xccl_api_->groupStart(); + + for (int i = 0; i < comm_size_; ++i) { + xccl_api_->broadcast( + tensor.data_ptr(), + tensor_list[i].data_ptr(), + tensor.numel(), + getXcclDataType(tensor_list[i]), + i, + xccl_comm_, + stream); + } + + xccl_api_->groupEnd(); + + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::all_gather_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllGatherSingleOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(output); + ensureTensorContiguous(input); + + if (output.numel() != input.numel() * comm_size_) { + throw std::runtime_error( + "Output tensor size must be input_size * comm_size for all_gather_single"); + } + + tracing_->recordEventWithInputOutput( + "all_gather_single", rank_, {input}, {output}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {input}); + + work->recordStart(); + + onecclResult_t result = xccl_api_->allGather( + input.data_ptr(), + output.data_ptr(), + input.numel(), + getXcclDataType(input), + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL AllGather failed", result); + } + + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::reduce_scatter( + at::Tensor& output, + const std::vector& input_list, + ReduceOp op, + bool async_op, + const ReduceScatterOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(output); + + if (input_list.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "input_list size must equal comm_size for reduce_scatter"); + } + + // Check that all input tensors are contiguous and have correct size + for (const auto& t : input_list) { + ensureTensorContiguous(t); + if (t.numel() != 
output.numel()) { + throw std::runtime_error( + "All input tensors must have same size as output tensor"); + } + } + + tracing_->recordEventWithInputOutput( + "reduce_scatter", rank_, input_list, {output}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, + getOperationTimeout(options.timeout, options_.timeout), + input_list); + + work->recordStart(); + + // Use multiple reduce operations for reduce_scatter + xccl_api_->groupStart(); + + for (int i = 0; i < comm_size_; ++i) { + const auto dataType = getXcclDataType(input_list[i]); + if (i == rank_) { + // This rank receives the reduced result + xccl_api_->reduce( + input_list[i].data_ptr(), + output.data_ptr(), + output.numel(), + dataType, + getXcclReduceOp(op, xccl_comm_, dataType), + i, + xccl_comm_, + stream); + } else { + // Other ranks contribute to the reduction + xccl_api_->reduce( + input_list[i].data_ptr(), + nullptr, // Non-root ranks don't receive + input_list[i].numel(), + dataType, + getXcclReduceOp(op, xccl_comm_, dataType), + i, + xccl_comm_, + stream); + } + } + + xccl_api_->groupEnd(); + + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::reduce_scatter_single( + at::Tensor& output, + const at::Tensor& input, + ReduceOp op, + bool async_op, + const ReduceScatterSingleOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(output); + ensureTensorContiguous(input); + + if (input.numel() != output.numel() * comm_size_) { + throw std::runtime_error( + "Input tensor size must be output_size * comm_size for reduce_scatter_single"); + } + + tracing_->recordEventWithInputOutput( + "reduce_scatter_single", rank_, {input}, {output}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {input}); + + // Record start event before XCCL operation + work->recordStart(); + + const auto dataType = getXcclDataType(input); + onecclResult_t result = xccl_api_->reduceScatter( + input.data_ptr(), + output.data_ptr(), + output.numel(), + dataType, + getXcclReduceOp(op, xccl_comm_, dataType), + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL ReduceScatter failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::all_to_all_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllToAllSingleOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(output); + ensureTensorContiguous(input); + + if (input.numel() != output.numel()) { + throw std::runtime_error( + "Input and output tensors must have same size for all_to_all_single"); + } + + if (input.numel() % comm_size_ != 0) { + throw std::runtime_error( + "Tensor size must be divisible by comm_size for all_to_all_single"); + } + + tracing_->recordEventWithInputOutput( + "all_to_all_single", rank_, {input}, {output}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {input}); + + // Record start event before XCCL operation + work->recordStart(); + + size_t chunk_size = input.numel() / comm_size_; + const auto data_type = 
getXcclDataType(input); + + onecclResult_t result = xccl_api_->allToAll( + input.data_ptr(), + output.data_ptr(), + chunk_size, + data_type, + xccl_comm_, + stream); + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL AllToAll failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::all_to_all_v_single( + at::Tensor& output, + const at::Tensor& input, + const std::vector& output_split_sizes, + const std::vector& input_split_sizes, + bool async_op, + const AllToAllvSingleOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(output); + ensureTensorContiguous(input); + + // Validate split sizes vectors + if (input_split_sizes.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "input_split_sizes length must equal comm_size for all_to_all_v_single"); + } + + if (output_split_sizes.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "output_split_sizes length must equal comm_size for all_to_all_v_single"); + } + + tracing_->recordEventWithInputOutput( + "all_to_all_v_single", rank_, {input}, {output}); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {input}); + + // Record start event before XCCL operation + work->recordStart(); + + // Convert split sizes to arrays and calculate displacements + std::vector sendcounts(comm_size_); + std::vector recvcounts(comm_size_); + std::vector senddispls(comm_size_); + std::vector recvdispls(comm_size_); + + // Calculate the number of elements per slice along the first dimension + // For a tensor with shape [N, D1, D2, ..., Dk], each slice of size S along + // dim 0 contains S * D1 * D2 * ... * Dk elements + size_t elements_per_slice = input.numel() ? 
input.numel() / input.size(0) : 0; + const auto data_type = getXcclDataType(input); + const size_t type_size = wordSize(data_type); + + size_t sendoffset = 0; + size_t recvoffset = 0; + for (int i = 0; i < comm_size_; ++i) { + sendcounts[i] = input_split_sizes[i] * elements_per_slice; + recvcounts[i] = output_split_sizes[i] * elements_per_slice; + senddispls[i] = sendoffset; + recvdispls[i] = recvoffset; + sendoffset += sendcounts[i]; + recvoffset += recvcounts[i]; + } + + char* sptr = static_cast(input.data_ptr()); + char* rptr = static_cast(output.data_ptr()); + + xccl_api_->groupStart(); + + for (int i = 0; i < comm_size_; ++i) { + xccl_api_->send( + sptr + senddispls[i] * type_size, + sendcounts[i], + data_type, + i, + xccl_comm_, + stream); + xccl_api_->recv( + rptr + recvdispls[i] * type_size, + recvcounts[i], + data_type, + i, + xccl_comm_, + stream); + } + + xccl_api_->groupEnd(); + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::all_to_all( + const std::vector& output_tensor_list, + const std::vector& input_tensor_list, + bool async_op, + const AllToAllOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + if (output_tensor_list.size() != static_cast(comm_size_) || + input_tensor_list.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "Tensor list sizes must equal comm_size for all_to_all"); + } + + // Validate all tensors + for (int i = 0; i < comm_size_; ++i) { + ensureTensorContiguous(input_tensor_list[i]); + ensureTensorContiguous(output_tensor_list[i]); + } + + tracing_->recordEventWithInputOutput( + "all_to_all", rank_, input_tensor_list, output_tensor_list); + + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, + getOperationTimeout(options.timeout, options_.timeout), + input_tensor_list); + + // Record start event before XCCL operations + work->recordStart(); + + xccl_api_->groupStart(); + + for (int i = 0; i < comm_size_; ++i) { + // Send to rank i + xccl_api_->send( + input_tensor_list[i].data_ptr(), + input_tensor_list[i].numel(), + getXcclDataType(input_tensor_list[i]), + i, + xccl_comm_, + stream); + + // Receive from rank i + xccl_api_->recv( + output_tensor_list[i].data_ptr(), + output_tensor_list[i].numel(), + getXcclDataType(output_tensor_list[i]), + i, + xccl_comm_, + stream); + } + + xccl_api_->groupEnd(); + + // Record end event after XCCL operations + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::barrier( + bool async_op, + const BarrierOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + + tracing_->recordEvent("barrier"); + xpuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {}); + + // Record start event before XCCL operation + work->recordStart(); + + // Use pre-allocated XPU buffer for barrier + onecclResult_t result = xccl_api_->allReduce( + barrier_buffer_, + barrier_buffer_, + 1, + onecclFloat32, + onecclSum, + xccl_comm_, + stream); + + if (result != onecclSuccess) { + throw XCCLException(*xccl_api_, "XCCL Barrier failed", result); + } + + // Record end event after XCCL operation + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return 
work; +} + +std::shared_ptr TorchCommXCCL::scatter( + at::Tensor& output_tensor, + const std::vector& input_tensor_list, + int root, + bool async_op, + const ScatterOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(output_tensor); + + // Only the root rank needs valid input tensors + if (rank_ == root) { + if (input_tensor_list.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "input_tensor_list size must equal comm_size for scatter"); + } + + for (const auto& t : input_tensor_list) { + ensureTensorContiguous(t); + if (t.numel() != output_tensor.numel()) { + throw std::runtime_error( + "All input tensors must have same size as output tensor"); + } + } + } + + tracing_->recordEventWithInputOutput( + "scatter", root, input_tensor_list, {output_tensor}); + + xpuStream_t stream = getOperationStream(async_op); + std::vector input_tensors; + if (rank_ == root) { + input_tensors = input_tensor_list; + } + auto work = createWork( + stream, + getOperationTimeout(options.timeout, options_.timeout), + input_tensors); + + // Record start event before XCCL operations + work->recordStart(); + + // Implement scatter using point-to-point operations + if (rank_ == root) { + // Root sends to all ranks (except itself) + xccl_api_->groupStart(); + for (int i = 0; i < comm_size_; ++i) { + if (i != root) { + xccl_api_->send( + input_tensor_list[i].data_ptr(), + input_tensor_list[i].numel(), + getXcclDataType(input_tensor_list[i]), + i, + xccl_comm_, + stream); + } + } + xccl_api_->groupEnd(); + + // Root copies its own data using xpuMemcpyAsync + XPU_CHECK( + xpu_api_, + xpu_api_->memcpyAsync( + output_tensor.data_ptr(), + input_tensor_list[root].data_ptr(), + input_tensor_list[root].numel() * + input_tensor_list[root].element_size(), + stream), + "memcpyAsync failed"); + } else { + // Non-root ranks receive from root + xccl_api_->recv( + output_tensor.data_ptr(), + output_tensor.numel(), + getXcclDataType(output_tensor), + root, + xccl_comm_, + stream); + } + + // Record end event after XCCL operations + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::gather( + const std::vector& output_tensor_list, + const at::Tensor& input_tensor, + int root, + bool async_op, + const GatherOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + ensureTensorContiguous(input_tensor); + + // Only the root rank needs valid output tensors + if (rank_ == root) { + if (output_tensor_list.size() != static_cast(comm_size_)) { + throw std::runtime_error( + "output_tensor_list size must equal comm_size for gather"); + } + + for (const auto& t : output_tensor_list) { + ensureTensorContiguous(t); + if (t.numel() != input_tensor.numel()) { + throw std::runtime_error( + "All output tensors must have same size as input tensor"); + } + } + } + + tracing_->recordEventWithInputOutput( + "gather", root, {input_tensor}, output_tensor_list); + + xpuStream_t stream = getOperationStream(async_op); + std::vector output_tensors; + if (rank_ == root) { + output_tensors = output_tensor_list; + } + auto work = createWork( + stream, + getOperationTimeout(options.timeout, options_.timeout), + {input_tensor}); + + // Record start event before XCCL operations + work->recordStart(); + + if (rank_ == root) { + // Root receives from all ranks (except itself) + xccl_api_->groupStart(); + for (int i = 0; i < comm_size_; ++i) { + if (i != root) { + 
xccl_api_->recv( + output_tensor_list[i].data_ptr(), + output_tensor_list[i].numel(), + getXcclDataType(output_tensor_list[i]), + i, + xccl_comm_, + stream); + } + } + xccl_api_->groupEnd(); + + // Root copies its own data using xpuMemcpyAsync + XPU_CHECK( + xpu_api_, + xpu_api_->memcpyAsync( + output_tensor_list[root].data_ptr(), + input_tensor.data_ptr(), + input_tensor.numel() * input_tensor.element_size(), + stream), + "memcpyAsync failed"); + } else { + // Non-root ranks send to root + xccl_api_->send( + input_tensor.data_ptr(), + input_tensor.numel(), + getXcclDataType(input_tensor), + root, + xccl_comm_, + stream); + } + + // Record end event after XCCL operations + work->recordEnd(); + + // Enqueue the work after events have been recorded + enqueueWork(work, stream); + + return work; +} + +std::shared_ptr TorchCommXCCL::split( + const std::vector& ranks, + const std::string& name, + const CommOptions& options) { + // Validate the ranks list + checkAndAbortIfTimedOutOrError(); + std::unordered_set rank_seen; + for (int rank : ranks) { + if (rank < 0 || rank >= comm_size_) { + throw std::runtime_error( + "Invalid rank " + std::to_string(rank) + + " in ranks. Valid ranks are 0 to " + std::to_string(comm_size_ - 1)); + } + if (rank_seen.find(rank) != rank_seen.end()) { + throw std::runtime_error( + "Rank " + std::to_string(rank) + " appears multiple times in ranks"); + } + rank_seen.insert(rank); + } + + // Determine the color for this rank + int color; + int new_rank = -1; // Rank within the new communicator + + if (ranks.empty()) { + // Empty list means exclude all ranks - use XCCL_SPLIT_NOCOLOR +#ifdef XCCL_SPLIT_NOCOLOR + color = XCCL_SPLIT_NOCOLOR; +#else + throw std::runtime_error("XCCL_SPLIT_NOCOLOR is not defined"); +#endif + new_rank = -1; // Will not participate in new communicator + } else { + // Check if current rank is in the non-empty list + auto it = std::find(ranks.begin(), ranks.end(), rank_); + if (it == ranks.end()) { + // Current rank is not in the non-empty list - this is an error + throw std::runtime_error( + "Current rank " + std::to_string(rank_) + + " is not included in the provided ranks list"); + } + // Set color to the lowest rank in the group and calculate new rank + color = *std::min_element(ranks.begin(), ranks.end()); + new_rank = static_cast(std::distance(ranks.begin(), it)); + } + + // Create a new XCCL communicator + onecclComm_t new_comm; + onecclConfig_t config = ONECCL_CONFIG_INITIALIZER; + // Note: oneCCL does not have a commName field like NCCL + + // Populate XCCL config from user-provided hints + populateXcclConfigFromHints(config, options, name); + + // TODO: xccl says that this is not supposed to be called if any operation + // is outstanding on the comm. We should check for that. + // TODO: what happens if one rank fails but the others succeed, need to + // handle the error case. + // TODO: is this sharing any resources with the original comm? 
+  onecclResult_t result =
+      xccl_api_->commSplit(xccl_comm_, color, new_rank, &new_comm, &config);
+  if (result != onecclSuccess) {
+    throw XCCLException(*xccl_api_, "XCCL split failed", result);
+  }
+  if (new_rank == -1) {
+    return nullptr; // Rank is not in the new group
+  }
+
+  auto new_torchcomm =
+      std::shared_ptr<TorchCommXCCL>(new TorchCommXCCL(new_comm));
+  new_torchcomm->xccl_api_ = xccl_api_;
+  new_torchcomm->xpu_api_ = xpu_api_;
+  new_torchcomm->init(device_, name, options);
+
+  return new_torchcomm;
+}
+
+void TorchCommXCCL::register_address(
+    const TorchCommXCCL::AddressWithLen& addr) {
+  // We got a register after we got rid of the comm. Is this a fatal error?
+  if (!xccl_comm_) {
+    return;
+  }
+
+  if (memoryRegistrationHandles_.contains(addr.addr)) {
+    throw std::runtime_error("Memory already registered with XCCL");
+  }
+  void* handle = nullptr;
+  onecclResult_t result =
+      xccl_api_->commRegister(xccl_comm_, addr.addr, addr.len, &handle);
+  if (result != onecclSuccess) {
+    throw std::runtime_error(
+        "Failed to register memory with XCCL: " +
+        std::string(xccl_api_->getErrorString(result)));
+  }
+  memoryRegistrationHandles_.emplace(addr.addr, RegistrationHandle(handle));
+}
+
+void TorchCommXCCL::deregister_address(const TorchCommXCCL::Address& addr) {
+  // We got a deregister after we got rid of the comm. Is this a fatal error?
+  if (!xccl_comm_) {
+    return;
+  }
+
+  auto it = memoryRegistrationHandles_.find(addr.addr);
+  if (it == memoryRegistrationHandles_.end()) {
+    // It's possible the memory was registered for a different comm but
+    // registration failed for this comm.
+    return;
+  }
+
+  void* handle = it->second.regHandle;
+  onecclResult_t result = xccl_api_->commDeregister(xccl_comm_, handle);
+  if (result != onecclSuccess) {
+    throw std::runtime_error(
+        "Failed to deregister memory with XCCL: " +
+        std::string(xccl_api_->getErrorString(result)));
+  }
+
+  memoryRegistrationHandles_.erase(it);
+}
+
+XCCLException::XCCLException(
+    XcclApi& xccl_api,
+    const std::string& message,
+    onecclResult_t result)
+    : message_(message + ": " + xccl_api.getErrorString(result)),
+      result_(result) {}
+
+const char* XCCLException::what() const noexcept {
+  return message_.c_str();
+}
+
+} // namespace comms
+} // namespace torch
+
+namespace {
+class XCCLRegistration {
+ public:
+  XCCLRegistration() {
+    torch::comms::TorchCommFactory::get().register_backend("xccl", []() {
+      return std::make_shared<TorchCommXCCL>();
+    });
+  }
+};
+
+static XCCLRegistration registration{};
+} // namespace
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp
new file mode 100644
index 00000000..a4b3af2b
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp
@@ -0,0 +1,374 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include // @manual=//caffe2:torch-cpp + +#include "comms/torchcomms/TorchComm.hpp" +#include "comms/torchcomms/TorchCommBackend.hpp" +#include "comms/torchcomms/TorchCommBatch.hpp" +#include "comms/torchcomms/TorchCommTracing.hpp" +#include "comms/torchcomms/device/XpuApi.hpp" +#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp" +#include "comms/torchcomms/xccl/XcclApi.hpp" + +namespace torch { +namespace comms { + +constexpr size_t kMaxEventPoolSize = 1000; + +// Custom exception class for better error handling +class XCCLException : public std::exception { + public: + XCCLException(XcclApi& api, const std::string& message, onecclResult_t result); + + const char* what() const noexcept override; + onecclResult_t getResult() const; + + private: + std::string message_; + onecclResult_t result_; +}; + +class TorchCommXCCL : public TorchCommBackend, + public std::enable_shared_from_this { + public: + static constexpr std::string_view kBackendName = "xccl"; + + TorchCommXCCL(); + ~TorchCommXCCL() override; + + // Delete copy and move operations + TorchCommXCCL(const TorchCommXCCL&) = delete; + TorchCommXCCL(TorchCommXCCL&&) = delete; + TorchCommXCCL& operator=(const TorchCommXCCL&) = delete; + TorchCommXCCL& operator=(TorchCommXCCL&&) = delete; + + void init( + at::Device device, + const std::string& name, + const CommOptions& options = {}) override; + void finalize() override; + int getRank() const override; + int getSize() const override; + std::string_view getBackendName() const override; + std::string_view getCommName() const override; + + // Point-to-Point Operations + std::shared_ptr send( + const at::Tensor& tensor, + int dst, + bool async_op, + const SendOptions& options = {}) override; + std::shared_ptr recv( + at::Tensor& tensor, + int src, + bool async_op, + const RecvOptions& options = {}) override; + + // Batch P2P Operations + std::shared_ptr batch_op_issue( + const std::vector& ops, + bool async_op, + const BatchP2POptions& options = {}) override; + + // Collective Operations + std::shared_ptr broadcast( + at::Tensor& tensor, + int root, + bool async_op, + const BroadcastOptions& options = {}) override; + std::shared_ptr all_reduce( + at::Tensor& tensor, + ReduceOp op, + bool async_op, + const AllReduceOptions& options = {}) override; + std::shared_ptr reduce( + const at::Tensor& tensor, + int root, + ReduceOp op, + bool async_op, + const ReduceOptions& options = {}) override; + std::shared_ptr all_gather( + const std::vector& tensor_list, + const at::Tensor& tensor, + bool async_op, + const AllGatherOptions& options = {}) override; + std::shared_ptr all_gather_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllGatherSingleOptions& options = {}) override; + std::shared_ptr reduce_scatter( + at::Tensor& output, + const std::vector& input_list, + ReduceOp op, + bool async_op, + const ReduceScatterOptions& options = {}) override; + std::shared_ptr reduce_scatter_single( + at::Tensor& output, + const at::Tensor& input, + ReduceOp op, + bool async_op, + const ReduceScatterSingleOptions& options = {}) override; + std::shared_ptr all_to_all_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllToAllSingleOptions& options = {}) override; + std::shared_ptr all_to_all_v_single( + at::Tensor& output, + const at::Tensor& input, + const std::vector& output_split_sizes, + const std::vector& 
      input_split_sizes,
+      bool async_op,
+      const AllToAllvSingleOptions& options = {}) override;
+  std::shared_ptr<TorchWork> all_to_all(
+      const std::vector<at::Tensor>& output_tensor_list,
+      const std::vector<at::Tensor>& input_tensor_list,
+      bool async_op,
+      const AllToAllOptions& options = {}) override;
+  std::shared_ptr<TorchWork> barrier(
+      bool async_op,
+      const BarrierOptions& options = {}) override;
+
+  // Scatter and Gather Operations
+  std::shared_ptr<TorchWork> scatter(
+      at::Tensor& output_tensor,
+      const std::vector<at::Tensor>& input_tensor_list,
+      int root,
+      bool async_op,
+      const ScatterOptions& options = {}) override;
+  std::shared_ptr<TorchWork> gather(
+      const std::vector<at::Tensor>& output_tensor_list,
+      const at::Tensor& input_tensor,
+      int root,
+      bool async_op,
+      const GatherOptions& options = {}) override;
+
+  // Communicator Management
+  std::shared_ptr<TorchCommBackend> split(
+      const std::vector<int>& ranks,
+      const std::string& name,
+      const CommOptions& options = {}) override;
+
+  // Friends with access to TorchCommXCCL internals
+  friend class TorchWorkXCCL;
+  // NOTE: CachingAllocatorHook is not implemented for XPU/SYCL yet
+  // friend class CachingAllocatorHookImpl;
+  friend class TorchCommWindowXCCL;
+
+  // Getter for the XPU API (for friend classes)
+  XpuApi* getXpuApi() const {
+    return xpu_api_.get();
+  }
+
+  // Getter for the XCCL API (for friend classes)
+  XcclApi* getXcclApi() const {
+    return xccl_api_.get();
+  }
+
+  // Method to override the XCCL API implementation for testing
+  void setXcclApi(std::shared_ptr<XcclApi> api) {
+    xccl_api_ = std::move(api);
+  }
+
+  // Method to override the XPU API implementation for testing
+  void setXpuApi(std::shared_ptr<XpuApi> api) {
+    xpu_api_ = std::move(api);
+  }
+
+  const CommOptions& getOptions() const override {
+    return options_;
+  }
+
+  const at::Device& getDevice() const override {
+    return device_;
+  }
+
+ protected:
+  // Event management for friend classes
+  xpuEvent_t getEvent();
+  void returnEvent(xpuEvent_t&& event);
+  void abortXcclComm();
+
+  enum class CommState {
+    NORMAL,
+    ERROR,
+    TIMEOUT,
+  };
+
+  struct Address {
+    void* addr;
+  };
+
+  struct AddressWithLen {
+    void* addr;
+    size_t len;
+  };
+
+  std::atomic<CommState> comm_state_{
+      CommState::NORMAL}; // State of the communicator
+
+  void register_address(const AddressWithLen& addr);
+  void deregister_address(const Address& addr);
+  onecclDataType_t getXcclDataType(const at::Tensor& tensor);
+  std::shared_ptr<TorchWorkXCCL> createWork(
+      xpuStream_t stream,
+      std::chrono::milliseconds timeout,
+      const std::vector<at::Tensor>& inputTensors);
+
+ private:
+  // Helper that automatically cleans up premul sums.
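+  // Usage sketch (illustrative): the implicit constructor wraps built-in ops
+  // (sum, prod, min, max, avg), which need no cleanup, while the explicit
+  // premul-sum constructor records the comm so the destructor can call
+  // redOpDestroy(). The conversion operator lets a RedOpRAII be passed
+  // wherever an onecclRedOp_t is expected, e.g.:
+  //   RedOpRAII red_op = getXcclReduceOp(op, xccl_comm_, data_type);
+  //   xccl_api_->allReduce(src, dst, count, data_type, red_op, xccl_comm_, s);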
+  struct RedOpRAII {
+    /* implicit */ RedOpRAII(onecclRedOp_t op);
+
+    // Constructor for premul-sum reductions
+    explicit RedOpRAII(
+        const ReduceOp& op,
+        const onecclComm_t comm,
+        const onecclDataType_t dataType,
+        std::shared_ptr<XcclApi> xccl_api);
+
+    RedOpRAII() = delete;
+    RedOpRAII(const RedOpRAII&) = delete;
+    RedOpRAII& operator=(const RedOpRAII&) = delete;
+    RedOpRAII(RedOpRAII&& tmp) = delete;
+    RedOpRAII& operator=(RedOpRAII&&) = delete;
+    ~RedOpRAII();
+
+    operator onecclRedOp_t() const {
+      return xcclRedOp_;
+    }
+
+    onecclRedOp_t xcclRedOp_{onecclMaxRedOp};
+    onecclComm_t comm_{nullptr};
+    std::shared_ptr<XcclApi> xccl_api_;
+  };
+
+  // Struct to hold the registration handle for a buffer
+  struct RegistrationHandle {
+    void* regHandle;
+
+    explicit RegistrationHandle(void* regHandle) : regHandle{regHandle} {}
+
+    RegistrationHandle(RegistrationHandle&& other) noexcept
+        : regHandle{other.regHandle} {
+      other.regHandle = nullptr;
+    }
+
+    RegistrationHandle(const RegistrationHandle&) = delete;
+    RegistrationHandle& operator=(const RegistrationHandle&) = delete;
+    RegistrationHandle& operator=(RegistrationHandle&&) = delete;
+
+    ~RegistrationHandle() = default;
+  };
+
+  // Constructor for split communicators
+  explicit TorchCommXCCL(const onecclComm_t xccl_comm);
+
+  // Private utility methods
+  size_t wordSize(onecclDataType_t type) const;
+  RedOpRAII getXcclReduceOp(
+      const ReduceOp& op,
+      const onecclComm_t comm,
+      const onecclDataType_t dataType);
+  void timeoutWatchdog() noexcept;
+  void checkInitialized() const;
+  void checkAndAbortIfTimedOutOrError();
+  void checkWorkQueue(bool isMainThread);
+  void enqueueWork(std::shared_ptr<TorchWorkXCCL> work, xpuStream_t stream);
+  // XPU doesn't support graph capture yet
+  // bool getGraphCaptureMode();
+  xpuStream_t getOperationStream(bool async_op);
+  void ensureTensorContiguous(const at::Tensor& tensor);
+
+  void attachMemoryHook();
+  void detachMemoryHook();
+
+  // Member variables
+  onecclComm_t xccl_comm_{};
+  at::Device device_;
+  int comm_size_{};
+  int rank_{};
+  CommOptions options_;
+  size_t max_event_pool_size_{};
+  std::optional<xpuStream_t> internal_stream_; // Initialized in init()
+  std::optional<xpuEvent_t>
+      dependency_event_; // Pre-allocated event for stream dependencies
+  void* barrier_buffer_{}; // Pre-allocated XPU buffer for barrier operations
+  enum class InitializationState {
+    UNINITIALIZED,
+    INITIALIZED,
+    FINALIZED,
+  } init_state_;
+
+  // Map from a registered buffer address to its registration handle
+  std::map<void*, RegistrationHandle> memoryRegistrationHandles_;
+
+  // XCCL API abstraction
+  std::shared_ptr<XcclApi> xccl_api_;
+
+  // XPU API abstraction
+  std::shared_ptr<XpuApi> xpu_api_;
+
+  // Event pool management
+  std::queue<xpuEvent_t> event_pool_;
+  std::mutex event_pool_mutex_;
+
+  // Work tracking per stream
+  TorchWorkXCCLQueue workq_;
+
+  // Timeout monitoring
+  std::thread timeout_thread_;
+  std::atomic<bool> shutdown_;
+  std::condition_variable timeout_cv_;
+  std::mutex timeout_mutex_;
+
+  std::shared_ptr<TorchCommTracing> tracing_;
+  bool high_priority_stream_{false};
+  std::string name_;
+
+  // Graph capture mode work references
+  // Keep references to work objects during graph capture to prevent premature
+  // destruction, organized per graph using capture ID
+  std::unordered_map<
+      unsigned long long,
+      std::vector<std::shared_ptr<TorchWorkXCCL>>>
+      graph_capture_work_refs_;
+  std::mutex graph_capture_work_mutex_;
+
+  // Structure to hold cleanup data for XPU user objects
+  // NOTE: Graph capture cleanup is currently disabled for XPU/SYCL as the
+  // required APIs (userObjectCreate, graphRetainUserObject) are not yet
+  // available
+  struct GraphCleanupData {
+    TorchCommXCCL* comm;
+    unsigned long long graph_id;
+
+    GraphCleanupData(TorchCommXCCL* comm_, unsigned long long id)
+        : comm(comm_), graph_id(id) {}
+  };
+
+  // Static callback function for XPU user object cleanup
+  // NOTE: Currently disabled - XPU/SYCL does not have an equivalent callback
+  // mechanism
+  // static void graphCleanupCallback(void* userData);
+
+  friend class TorchWorkXCCLQueueCommTest;
+};
+
+} // namespace comms
+} // namespace torch
diff --git a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp
new file mode 100644
index 00000000..5462d5bf
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp
@@ -0,0 +1,307 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
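+//
+// Bootstrap flow: query rank/size from the environment, pick (or validate)
+// an XPU device, exchange the oneCCL unique ID through a Store -- either one
+// supplied by the caller or an internal TCPStore -- and finally create the
+// communicator via commInitRankConfig.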
+ +#include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp" +#include +#include // @manual +#include "comms/torchcomms/StoreManager.hpp" +#include "comms/torchcomms/TorchCommLogging.hpp" +#include "comms/torchcomms/TorchCommUtils.hpp" +#include "comms/torchcomms/xccl/TorchCommXCCL.hpp" + +namespace torch { +namespace comms { + +// Initialize the static counter +int TorchCommXCCLBootstrap::counter_ = 0; + +const std::string kUniqueidXchgMethodAuto = "auto"; +const std::string kUniqueidXchgMethodTCPStore = "tcpstore"; +const std::string kUniqueidXchgMethodDefault = kUniqueidXchgMethodAuto; + +TorchCommXCCLBootstrap::TorchCommXCCLBootstrap( + c10::intrusive_ptr store, + c10::Device device, + std::shared_ptr xccl_api, + std::shared_ptr xpu_api, + std::chrono::milliseconds timeout) + : timeout_(timeout), + store_(store), + created_internal_store_(false), + device_(device), + xccl_api_(xccl_api), + xpu_api_(xpu_api) { + // Query rank and size using the utility function + auto ranksize = query_ranksize(); + rank_ = ranksize.first; + comm_size_ = ranksize.second; + + const char* uniqueid_xchg_env = + std::getenv("TORCHCOMM_XCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD"); + if (uniqueid_xchg_env == nullptr) { + TC_LOG(INFO) + << "TORCHCOMM_XCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD not set, " + << "defaulting to " << kUniqueidXchgMethodDefault; + uniqueid_xchg_method_ = kUniqueidXchgMethodDefault; + } else { + uniqueid_xchg_method_ = uniqueid_xchg_env; + } + std::transform( + uniqueid_xchg_method_.begin(), + uniqueid_xchg_method_.end(), + uniqueid_xchg_method_.begin(), + [](unsigned char c) { return std::tolower(c); }); + + if (device_.index() == -1) { + int device_count; + XPU_CHECK( + xpu_api_, + xpu_api_->getDeviceCount(&device_count), + "Failed to get XPU device count"); + + device_ = c10::Device(c10::kXPU, rank_ % device_count); + TC_LOG(INFO) << "User did not provide device ID; using device xpu:" + << static_cast(device_.index()); + } + + XPU_CHECK( + xpu_api_, + xpu_api_->setDevice(device_.index()), + "Failed to set device to " + std::to_string(device_.index())); + + // Allocate XPU memory for a single float32 value used in barrier operations + XPU_CHECK( + xpu_api_, + xpu_api_->malloc(&barrier_buffer_, sizeof(float)), + "Failed to allocate barrier buffer"); +} + +TorchCommXCCLBootstrap::~TorchCommXCCLBootstrap() { + if (barrier_buffer_ != nullptr) { + XPU_CHECK( + xpu_api_, + xpu_api_->free(barrier_buffer_), + "Failed to free barrier buffer"); + barrier_buffer_ = nullptr; + } +} + +std::string TorchCommXCCLBootstrap::getXCCLStoreKey() { + std::string key = getXCCLStoreKeyPrefix() + std::to_string(counter_); + counter_++; + return key; +} + +std::string TorchCommXCCLBootstrap::getXCCLStoreKeyPrefix() { + return "xccl_storekey_"; +}; + +int TorchCommXCCLBootstrap::getXCCLStoreKeyCounter() { + return counter_; +} + +onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueIdStore() { + onecclUniqueId uniqueId; + + auto key = getXCCLStoreKey(); + TC_LOG(INFO) << "[TC] Using store key: " << key << " for rank " << rank_; + + if (rank_ == 0) { + // Generate unique ID on rank 0 + TC_LOG(INFO) << "[TC] Rank 0: calling getUniqueId"; + onecclResult_t xcclErr = xccl_api_->getUniqueId(&uniqueId); + TC_LOG(INFO) << "[TC] Rank 0: getUniqueId returned " << xcclErr; + + if (xcclErr != onecclSuccess) { + throw std::runtime_error( + "Failed to get XCCL unique ID: " + + std::string(xccl_api_->getErrorString(xcclErr))); + } + + // Set the unique ID in the store + TC_LOG(INFO) << "[TC] Rank 0: setting unique ID in store"; 
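+    // The unique ID is treated as an opaque blob of sizeof(uniqueId) bytes
+    // and round-tripped through the store as a byte vector; the non-root
+    // branch below checks the size before reinterpreting the payload.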
+ std::vector vec( + reinterpret_cast(&uniqueId), + reinterpret_cast(&uniqueId) + sizeof(uniqueId)); + store_->set(key, vec); + TC_LOG(INFO) << "[TC] Rank 0: unique ID set in store"; + } else { + // Other ranks read the broadcast ID + TC_LOG(INFO) << "[TC] Rank " << rank_ << ": getting unique ID from store"; + auto vec = store_->get(key); + TC_LOG(INFO) << "[TC] Rank " << rank_ << ": got unique ID from store, size=" << vec.size(); + + if (vec.size() != sizeof(onecclUniqueId)) { + throw std::runtime_error("Invalid XCCL unique ID size"); + } + uniqueId = *(reinterpret_cast(vec.data())); + } + + TC_LOG(INFO) << "[TC] Rank " << rank_ << ": unique ID exchange completed"; + return uniqueId; +} + +onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueIdTCPStore( + std::string_view name) { + store_ = + StoreManager::get().getStore(TorchCommXCCL::kBackendName, name, timeout_); + created_internal_store_ = true; + + return exchangeUniqueIdStore(); +} + +bool TorchCommXCCLBootstrap::isTCPStoreEnabled() { + return std::getenv("MASTER_ADDR") && std::getenv("MASTER_PORT"); +} + +onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueId(std::string_view name) { + if (store_ != nullptr) { + return exchangeUniqueIdStore(); + } + + bool is_tcp_store_enabled = isTCPStoreEnabled(); + if (uniqueid_xchg_method_ != kUniqueidXchgMethodAuto && + uniqueid_xchg_method_ != kUniqueidXchgMethodTCPStore) { + throw std::runtime_error( + "Invalid unique ID exchange method " + uniqueid_xchg_method_); + } + if (!is_tcp_store_enabled) { + throw std::runtime_error("No way to exchange unique ID"); + } + return exchangeUniqueIdTCPStore(name); +} + +void TorchCommXCCLBootstrap::cleanupTCPStore(onecclComm_t xccl_comm) { + if (created_internal_store_) { + // Delete the internal store object and do a barrier to ensure that all + // processes have deleted their store object too. This way, when we + // create the next torchcomm, we can use the same port to create a new store + // object. + store_.reset(); + + auto stream = xpu_api_->getCurrentXPUStream(device_.index()); + onecclResult_t result = xccl_api_->allReduce( + barrier_buffer_, + barrier_buffer_, + 1, + onecclFloat32, + onecclSum, + xccl_comm, + stream); + if (result != onecclSuccess) { + TC_LOG(ERROR) << "XCCL AllReduce failed: " + << xccl_api_->getErrorString(result); + } + + XPU_CHECK( + xpu_api_, + xpu_api_->streamSynchronize(stream), + "Stream synchronization failed"); + } +} + +// Helper function to populate XCCL config from hints +void populateXcclConfigFromHints( + onecclConfig_t& config, + const CommOptions& options, + const std::string& name) { + // Iterate over the hints and set the corresponding fields in the config. For + // string arguments, XCCL uses a "const char*" instead of a std::string, so + // it is hard to figure out the ownership structure. Here, we create a copy + // of the string and pass it to XCCL, so that it is responsible for freeing + // it. 
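+  // Example: hints = {{"blocking", "1"}, {"max_ctas", "16"}} yields
+  // config.blocking = 1 and config.maxCTAs = 16; NCCL-only or unrecognized
+  // keys fall through to the warning branches below and are ignored.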
+ + for (const auto& [key, val] : options.hints) { + if (key == "blocking") { + config.blocking = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.blocking=" << config.blocking; + } else if (key == "cgaClusterSize" || key == "cga_cluster_size") { + config.cgaClusterSize = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name << "] Setting config.cgaClusterSize=" + << config.cgaClusterSize; + } else if (key == "minCTAs" || key == "min_ctas") { + config.minCTAs = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.minCTAs=" << config.minCTAs; + } else if (key == "maxCTAs" || key == "max_ctas") { + config.maxCTAs = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.maxCTAs=" << config.maxCTAs; + } else if (key == "netName") { + config.netName = strdup(val.c_str()); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.netName=" << config.netName; + } else if (key == "splitShare" || key == "split_share") { + config.splitShare = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.splitShare=" << config.splitShare; + } else if (key == "trafficClass" || key == "traffic_class" || + key == "commName" || + key == "collnetEnable" || key == "collnet_enable" || + key == "CTAPolicy" || key == "cta_policy" || + key == "shrinkShare" || + key == "nvlsCTAs" || key == "nvls_ctas" || + key == "nChannelsPerNetPeer" || key == "n_channels_per_net_peer" || + key == "nvlinkCentricSched" || key == "nvlink_centric_sched") { + TC_LOG(WARNING) + << "XCCL hint '" << key + << "' is NCCL-specific and not supported by oneCCL, ignoring for comm '" + << name << "'"; + } else { + TC_LOG(WARNING) + << "XCCL hint '" << key + << "' is not supported in this XCCL version, ignoring for comm '" + << name << "'"; + } + } +} + +onecclComm_t TorchCommXCCLBootstrap::createXcclComm( + const std::string& name, + const CommOptions& options) { + onecclUniqueId uniqueId; + onecclComm_t xccl_comm = nullptr; + + TC_LOG(INFO) << "[TC] Exchanging unique ID for comm '" << name << "'"; + uniqueId = exchangeUniqueId(name); + TC_LOG(INFO) << "[TC] Unique ID exchanged"; + + // TODO: add logging on failures and successes + // TODO: use scalable init + // TODO: get the local rank + TC_LOG(INFO) << "[TC] Initializing XCCL config"; + onecclConfig_t config = ONECCL_CONFIG_INITIALIZER; + // Note: oneCCL does not have a commName field like NCCL + + // Populate XCCL config from user-provided hints + populateXcclConfigFromHints(config, options, name); + + // Set device for oneCCL before initializing communicator + TC_LOG(INFO) << "[TC] Setting oneCCL device to " << device_.index(); + onecclResult_t xcclErr = xccl_api_->setDevice(device_.index()); + if (xcclErr != onecclSuccess) { + throw std::runtime_error( + "Failed to set oneCCL device: " + + std::string(xccl_api_->getErrorString(xcclErr))); + } + + TC_LOG(INFO) << "[TC] Calling commInitRankConfig with rank=" << rank_ + << " comm_size=" << comm_size_; + xcclErr = xccl_api_->commInitRankConfig( + &xccl_comm, comm_size_, uniqueId, rank_, &config); + TC_LOG(INFO) << "[TC] commInitRankConfig returned: " << xcclErr; + + if (xcclErr != onecclSuccess || xccl_comm == nullptr) { + throw std::runtime_error( + "Failed to initialize XCCL communicator: " + + std::string(xccl_api_->getErrorString(xcclErr))); + } + + TC_LOG(INFO) << "[TC] XCCL communicator initialized, cleaning up TCPStore"; + cleanupTCPStore(xccl_comm); + + return xccl_comm; +} + +} // namespace comms +} // namespace torch diff --git 
a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp new file mode 100644 index 00000000..152611ef --- /dev/null +++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp @@ -0,0 +1,85 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#pragma once + +#include + +#include +// #include // @manual=third-party//xpu:xpu-lazy +#include // @manual=//caffe2:torch-cpp + +#include "comms/torchcomms/TorchCommOptions.hpp" +#include "comms/torchcomms/device/XpuApi.hpp" +#include "comms/torchcomms/xccl/XcclApi.hpp" +#include +#include + +namespace torch { +namespace comms { + +constexpr uint16_t kTCPStorePort = 29500; + +class TorchCommXCCLBootstrap { + public: + TorchCommXCCLBootstrap( + c10::intrusive_ptr store, + c10::Device device, + std::shared_ptr xccl_api, + std::shared_ptr xpu_api, + std::chrono::milliseconds timeout); + ~TorchCommXCCLBootstrap(); + + // Delete copy and move operations + TorchCommXCCLBootstrap(const TorchCommXCCLBootstrap&) = delete; + TorchCommXCCLBootstrap& operator=(const TorchCommXCCLBootstrap&) = delete; + TorchCommXCCLBootstrap(TorchCommXCCLBootstrap&&) = delete; + TorchCommXCCLBootstrap& operator=(TorchCommXCCLBootstrap&&) = delete; + + onecclComm_t createXcclComm( + const std::string& name, + const CommOptions& options = {}); + static std::string getXCCLStoreKey(); + static std::string getXCCLStoreKeyPrefix(); + static int getXCCLStoreKeyCounter(); + + int getRank() { + return rank_; + } + int getSize() { + return comm_size_; + } + c10::Device getDevice() { + return device_; + } + + private: + onecclUniqueId exchangeUniqueId(std::string_view name); + onecclUniqueId exchangeUniqueIdStore(); + onecclUniqueId exchangeUniqueIdTCPStore(std::string_view name); + bool isTCPStoreEnabled(); + void cleanupTCPStore(onecclComm_t xccl_comm); + + private: + const std::chrono::milliseconds timeout_; + static int counter_; + + c10::intrusive_ptr store_; + bool created_internal_store_; + c10::Device device_; + std::shared_ptr xccl_api_; + std::shared_ptr xpu_api_; + void* barrier_buffer_{nullptr}; + int rank_; + int comm_size_; + + std::string uniqueid_xchg_method_; +}; + +// Helper function to populate XCCL config from hints +void populateXcclConfigFromHints( + onecclConfig_t& config, + const CommOptions& options, + const std::string& name); + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/TorchCommXCCLPy.cpp b/comms/torchcomms/xccl/TorchCommXCCLPy.cpp new file mode 100644 index 00000000..aa1780ed --- /dev/null +++ b/comms/torchcomms/xccl/TorchCommXCCLPy.cpp @@ -0,0 +1,18 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include +#include +#include +#include +#include + +#include "comms/torchcomms/xccl/TorchCommXCCL.hpp" + +namespace py = pybind11; +using namespace torch::comms; + +PYBIND11_MODULE(_comms_xccl, m) { + m.doc() = "XCCL specific python bindings for TorchComm"; + + py::class_>(m, "TorchCommXCCL"); +} diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp new file mode 100644 index 00000000..79c35d94 --- /dev/null +++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp @@ -0,0 +1,458 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
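+//
+// Helper implementations for TorchCommXCCL: dtype and reduce-op mapping,
+// premul-sum construction, work-queue checks, the timeout watchdog, and the
+// event pool shared by the collective operations.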
+
+#include "comms/torchcomms/xccl/TorchCommXCCL.hpp"
+// #include "comms/torchcomms/xccl/TorchCommXCCLCCA.hpp"
+
+#include
+#include
+#include "comms/torchcomms/TorchCommLogging.hpp"
+#include
+#include
+
+namespace torch {
+namespace comms {
+
+namespace {
+
+onecclDataType_t getXcclDataTypeInternal(const at::Tensor& tensor) {
+  switch (tensor.scalar_type()) {
+    case at::ScalarType::Float:
+      return onecclFloat32;
+    case at::ScalarType::Double:
+      return onecclFloat64;
+    case at::ScalarType::Half:
+      return onecclFloat16;
+    case at::ScalarType::BFloat16:
+      return onecclBfloat16;
+    case at::ScalarType::Int:
+      return onecclInt32;
+    case at::ScalarType::Long:
+      return onecclInt64;
+    case at::ScalarType::Char:
+      return onecclInt8;
+    case at::ScalarType::Byte:
+      return onecclUint8;
+    default:
+      throw std::runtime_error("Unsupported tensor data type for XCCL");
+  }
+}
+
+template <typename T, onecclDataType_t dataType>
+void createPreMulSum(
+    onecclRedOp_t* op,
+    const PreMulSumFactorT& factor,
+    const onecclComm_t& comm,
+    XcclApi* xccl_api) {
+  const bool is_tensor = std::holds_alternative<at::Tensor>(factor);
+  const auto residence =
+      is_tensor ? onecclScalarDevice : onecclScalarHostImmediate;
+
+  at::Tensor tensor = is_tensor ? std::get<at::Tensor>(factor) : at::Tensor();
+  T scalar_factor =
+      is_tensor ? T{} : static_cast<T>(std::get<double>(factor));
+  void* scalar = is_tensor ? tensor.data_ptr() : &scalar_factor;
+
+  TORCH_INTERNAL_ASSERT(
+      is_tensor ? dataType == getXcclDataTypeInternal(tensor)
+                : dataType != onecclBfloat16,
+      "PreMulSum factor type must match input data type");
+  xccl_api->redOpCreatePreMulSum(op, scalar, dataType, residence, comm);
+}
+
+} // namespace
+
+TorchCommXCCL::RedOpRAII::RedOpRAII(onecclRedOp_t op)
+    : xcclRedOp_(op), comm_(nullptr) {}
+
+TorchCommXCCL::RedOpRAII::RedOpRAII(
+    const ReduceOp& op,
+    const onecclComm_t comm,
+    const onecclDataType_t dataType,
+    std::shared_ptr<XcclApi> xccl_api)
+    : comm_(comm), xccl_api_(std::move(xccl_api)) {
+  TORCH_INTERNAL_ASSERT(
+      op == ReduceOp::RedOpType::PREMUL_SUM,
+      "Constructing premul_sum RedOpRAII with non-premul_sum RedOpType");
+
+  if (!op.factor().has_value()) {
+    xcclRedOp_ = onecclSum;
+    comm_ = nullptr;
+    return;
+  }
+
+  const auto& factor = op.factor().value();
+  switch (dataType) {
+    case onecclFloat16:
+      createPreMulSum<at::Half, onecclFloat16>(
+          &xcclRedOp_, factor, comm, xccl_api_.get());
+      break;
+    case onecclFloat32:
+      createPreMulSum<float, onecclFloat32>(
+          &xcclRedOp_, factor, comm, xccl_api_.get());
+      break;
+    case onecclBfloat16:
+      createPreMulSum<at::BFloat16, onecclBfloat16>(
+          &xcclRedOp_, factor, comm, xccl_api_.get());
+      break;
+    case onecclFloat64:
+      createPreMulSum<double, onecclFloat64>(
+          &xcclRedOp_, factor, comm, xccl_api_.get());
+      break;
+    default:
+      throw std::runtime_error(
+          "PreMulSum Data type must be half, float, bfloat16 or double");
+  }
+}
+
+TorchCommXCCL::RedOpRAII::~RedOpRAII() {
+  if (comm_) {
+    xccl_api_->redOpDestroy(xcclRedOp_, comm_);
+  }
+}
+
+size_t TorchCommXCCL::wordSize(onecclDataType_t type) const {
+  switch (type) {
+    case onecclChar:
+#if XCCL_MAJOR >= 2
+      // case onecclInt8:
+    case onecclUint8:
+#endif
+// #if HAVE_FP8
+//     case onecclFloat8e4m3:
+//     case onecclFloat8e5m2:
+// #endif
+      return 1;
+    case onecclHalf:
+#if HAVE_BF16
+    case onecclBfloat16:
+#endif
+      // case onecclFloat16:
+      return 2;
+    case onecclInt:
+    case onecclFloat:
+#if XCCL_MAJOR >= 2
+      // case onecclInt32:
+    case onecclUint32:
+      // case onecclFloat32:
+#endif
+      return 4;
+    case onecclInt64:
+    case onecclUint64:
+    case onecclDouble:
+      // case onecclFloat64:
+      return 8;
+    default:
+      return 0;
+  }
+}
+
+onecclDataType_t TorchCommXCCL::getXcclDataType(const at::Tensor& tensor) {
+  return getXcclDataTypeInternal(tensor);
+}
+
+TorchCommXCCL::RedOpRAII TorchCommXCCL::getXcclReduceOp(
+    const ReduceOp& op,
+    const onecclComm_t comm,
+    const onecclDataType_t dataType) {
+  switch (op) {
+    case ReduceOp::RedOpType::SUM:
+      return onecclSum;
+    case ReduceOp::RedOpType::PRODUCT:
+      return onecclProd;
+    case ReduceOp::RedOpType::MIN:
+      return onecclMin;
+    case ReduceOp::RedOpType::MAX:
+      return onecclMax;
+    case ReduceOp::RedOpType::BAND:
+      return onecclSum; // XCCL doesn't have bitwise AND, using SUM as fallback
+    case ReduceOp::RedOpType::BOR:
+      return onecclSum; // XCCL doesn't have bitwise OR, using SUM as fallback
+    case ReduceOp::RedOpType::BXOR:
+      return onecclSum; // XCCL doesn't have bitwise XOR, using SUM as fallback
+    case ReduceOp::RedOpType::PREMUL_SUM:
+      return RedOpRAII(op, comm, dataType, xccl_api_);
+    case ReduceOp::RedOpType::AVG:
+      return onecclAvg;
+    default:
+      throw std::runtime_error("Unsupported reduce operation");
+  }
+}
+
+void TorchCommXCCL::checkWorkQueue(bool isMainThread) {
+  TorchWorkXCCL::WorkStatus status = workq_.garbageCollect(isMainThread);
+
+  switch (status) {
+    case TorchWorkXCCL::WorkStatus::TIMEDOUT:
+      comm_state_ = CommState::TIMEOUT;
+      break;
+    case TorchWorkXCCL::WorkStatus::ERROR:
+      comm_state_ = CommState::ERROR;
+      break;
+    default:
+      // For COMPLETED, NOT_STARTED, and INPROGRESS, no state change needed
+      break;
+  }
+}
+
+// The timeout thread cannot make XCCL calls. The only XPU call it can make
+// is xpuEventQuery.
+void TorchCommXCCL::timeoutWatchdog() noexcept {
+  TC_LOG(INFO) << "Timeout thread starting for rank: " << rank_;
+  while (!shutdown_) {
+    {
+      std::unique_lock<std::mutex> lock(timeout_mutex_);
+      // Wait for a shorter interval to check work objects periodically.
+      // Wake up either after 1 second or immediately if shutdown is requested
+      timeout_cv_.wait_for(
+          lock, std::chrono::seconds(1), [this]() { return shutdown_.load(); });
+
+      // If we're shutting down, exit the loop
+      if (shutdown_) {
+        break;
+      }
+    }
+
+    // Check work objects for completion or timeout
+    checkWorkQueue(false);
+    if (comm_state_ != CommState::NORMAL &&
+        options_.abort_process_on_timeout_or_error) {
+      // Log the error and abort the process. We cannot abort the XCCL
+      // communicator as it is not safe to call XCCL operations from
+      // multiple threads at the same time.
+      if (comm_state_ == CommState::TIMEOUT) {
+        TC_LOG(ERROR) << "Aborting process due to timeout on rank " << rank_
                      << " - timeout watchdog detected operation timeout";
+      } else if (comm_state_ == CommState::ERROR) {
+        TC_LOG(ERROR) << "Aborting process due to error on rank " << rank_
+                      << " - timeout watchdog detected operation error.";
"; + } + abort(); + } + } + + TC_LOG(INFO) << "Timeout thread exiting for rank: " << rank_; +} + +void TorchCommXCCL::checkInitialized() const { + if (init_state_ != InitializationState::INITIALIZED) { + throw std::runtime_error("TorchCommXCCL not initialized"); + } +} + +void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { + // Nothing to check in graph capture mode + // XPU doesn't support graph capture yet + // if (getGraphCaptureMode()) { + // return; + // } + + // First, check work queue status + checkWorkQueue(true); + + if (comm_state_ == CommState::TIMEOUT) { + abortXcclComm(); + if (options_.abort_process_on_timeout_or_error) { + TC_LOG(ERROR) << "Aborting process due to timeout"; + abort(); + } else { + throw std::runtime_error("XCCL operation timed out"); + } + } else if (comm_state_ == CommState::ERROR) { + onecclResult_t asyncErr; + xccl_api_->commGetAsyncError(xccl_comm_, &asyncErr); + XCCLException xcclException(*xccl_api_, "XCCL Async Error", asyncErr); + abortXcclComm(); + if (options_.abort_process_on_timeout_or_error) { + TC_LOG(ERROR) << "Aborting process due to error: " + << xcclException.what(); + abort(); + } else { + throw xcclException; + } + } +} + +// bool TorchCommXCCL::getGraphCaptureMode() { +// xpuStream_t current_stream = +// xpu_api_->getCurrentXPUStream(device_.index()); +// xpuStreamCaptureStatus capture_status; + +// xpuError_t err = +// xpu_api_->streamIsCapturing(current_stream, &capture_status); +// if (err == xpuSuccess) { +// return capture_status == xpuStreamCaptureStatusActive; +// } + +// throw std::runtime_error( +// "Failed to check XPU stream capture status: " + +// std::string(xpu_api_->getErrorString(err))); +// } + +std::shared_ptr TorchCommXCCL::createWork( + xpuStream_t stream, + std::chrono::milliseconds timeout, + const std::vector& inputTensors) { + // Only create the work object without enqueuing it + auto work = std::make_shared( + shared_from_this(), stream, timeout, inputTensors, tracing_); + return work; +} + +void TorchCommXCCL::enqueueWork( + std::shared_ptr work, + xpuStream_t stream) { + // In graph capture mode, keep a reference to the work object to prevent + // premature destruction until the graph gets destroyed, organized per graph + // if (getGraphCaptureMode()) { + // xpuStreamCaptureStatus capture_status; + // unsigned long long graph_id; + // xpuGraph_t graph; + + // xpuError_t err = xpu_api_->streamGetCaptureInfo_v2( + // stream, &capture_status, &graph_id, &graph, nullptr, nullptr); + // if (err != xpuSuccess) { + // throw std::runtime_error( + // "Failed to get XPU stream capture info: " + + // std::string(xpu_api_->getErrorString(err))); + // } else if (capture_status == xpuStreamCaptureStatusActive) { + // std::lock_guard lock(graph_capture_work_mutex_); + + // // Check if this is the first work object for this graph + // bool is_first_work = graph_capture_work_refs_[graph_id].empty(); + + // // Add work reference to the per-graph container + // graph_capture_work_refs_[graph_id].push_back(work); + + // // If this is the first work object for this graph, set up automatic + // // cleanup + // if (is_first_work) { + // // Create cleanup data that will be passed to the callback + // auto* cleanup_data = new GraphCleanupData(this, graph_id); + + // // Create a XPU user object with our cleanup callback + // xpuUserObject_t user_object; + // err = xpu_api_->userObjectCreate( + // &user_object, + // cleanup_data, + // graphCleanupCallback, + // 1, // initial reference count + // xpuUserObjectNoDestructorSync); + // if 
(err != xpuSuccess) { + // // If we failed to create the user object, clean up manually + // delete cleanup_data; + // throw std::runtime_error( + // "Failed to create user object: " + + // std::string(xpu_api_->getErrorString(err))); + // } else { + // // Retain the user object in the graph so it gets cleaned up when the + // // graph is destroyed + // err = xpu_api_->graphRetainUserObject( + // graph, + // user_object, + // 1, // reference count + // xpuGraphUserObjectMove); + // if (err != xpuSuccess) { + // // If we failed to retain the user object, clean up manually + // delete cleanup_data; + // throw std::runtime_error( + // "Failed to retain user object: " + + // std::string(xpu_api_->getErrorString(err))); + // } + // } + // } + // } + // } else { + // Add work to stream's queue after events have been recorded + workq_.enqueueWork(std::move(work), stream); + // } +} + +// // Static callback function for XPU user object cleanup +// void XPURT_CB TorchCommXCCL::graphCleanupCallback(void* userData) { +// auto* cleanup_data = static_cast(userData); +// if (cleanup_data == nullptr || cleanup_data->comm == nullptr) { +// throw std::runtime_error("Invalid cleanup data"); +// } + +// // Clear the work references for this graph +// std::lock_guard lock( +// cleanup_data->comm->graph_capture_work_mutex_); +// cleanup_data->comm->graph_capture_work_refs_.erase(cleanup_data->graph_id); + +// // Clean up the cleanup data itself +// delete cleanup_data; +// } + +xpuStream_t TorchCommXCCL::getOperationStream(bool async_op) { + if (async_op) { + // Get current PyTorch XPU stream for this device + xpuStream_t current_stream = + xpu_api_->getCurrentXPUStream(device_.index()); + + // Record event on current stream and wait for it on internal stream + XPU_CHECK( + xpu_api_, + xpu_api_->eventRecord(dependency_event_.value(), current_stream), + "Failed to record dependency event"); + + XPU_CHECK( + xpu_api_, + xpu_api_->streamWaitEvent(internal_stream_.value(), dependency_event_.value(), 0), + "Failed to make internal stream wait for dependency event"); + + return internal_stream_.value(); + } else { + // Use the current PyTorch XPU stream for synchronous operations + return xpu_api_->getCurrentXPUStream(device_.index()); + } +} + +void TorchCommXCCL::ensureTensorContiguous(const at::Tensor& tensor) { + if (!tensor.is_contiguous()) { + throw std::runtime_error("Tensor must be contiguous for XCCL operations"); + } +} + +// Protected methods (not in the private section of the header) +xpuEvent_t TorchCommXCCL::getEvent() { + std::lock_guard lock(event_pool_mutex_); + + if (!event_pool_.empty()) { + xpuEvent_t event = std::move(event_pool_.front()); + event_pool_.pop(); + return event; + } + + // Create new event if pool is empty + xpuEvent_t event; + XPU_CHECK( + xpu_api_, + xpu_api_->eventCreateWithFlags(event, /*flags=*/0), + "Failed to create event"); + return event; +} + +void TorchCommXCCL::returnEvent(xpuEvent_t&& event) { + std::lock_guard lock(event_pool_mutex_); + + if (event_pool_.size() < max_event_pool_size_) { + event_pool_.push(std::move(event)); + } else { + // Pool is full, destroy the event + XPU_CHECK( + xpu_api_, xpu_api_->eventDestroy(event), "Failed to destroy event"); + } +} + +void TorchCommXCCL::attachMemoryHook() { + // NOTE: CachingAllocatorHook is not implemented for XPU/SYCL yet + // TODO: Implement XPU caching allocator hook when available + // CachingAllocatorHook::getInstance().registerComm(this); +} + +void TorchCommXCCL::detachMemoryHook() { + // NOTE: 
CachingAllocatorHook is not implemented for XPU/SYCL yet + // TODO: Implement XPU caching allocator hook when available + // CachingAllocatorHook::getInstance().deregisterComm(this); +} + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.cpp b/comms/torchcomms/xccl/TorchWorkXCCL.cpp new file mode 100644 index 00000000..7fd99c80 --- /dev/null +++ b/comms/torchcomms/xccl/TorchWorkXCCL.cpp @@ -0,0 +1,142 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp" +#include +#include "comms/torchcomms/TorchCommLogging.hpp" +#include "comms/torchcomms/xccl/TorchCommXCCL.hpp" + +namespace torch { +namespace comms { + +TorchWorkXCCL::TorchWorkXCCL( + std::shared_ptr comm, + xpuStream_t stream, + std::chrono::milliseconds timeout_ms, + const std::vector& inputTensors, + std::shared_ptr tracing) + : inputTensors_(inputTensors), + comm_(std::move(comm)), + stream_(stream), + timeout_ms_(timeout_ms), + state_(WorkStatus::NOT_STARTED), + tracing_(std::move(tracing)) { + // If not in graph capture mode, create the events for start and end + // recording + start_event_ = comm_->getEvent(); + end_event_ = comm_->getEvent(); + + // Events will be recorded around the actual XCCL operations +} + +TorchWorkXCCL::~TorchWorkXCCL() { + if (!comm_) { + return; + } + // If not in graph capture mode, return the events to the pool + comm_->returnEvent(std::move(start_event_)); + comm_->returnEvent(std::move(end_event_)); +} + +void TorchWorkXCCL::recordStart() { + XPU_CHECK( + comm_->getXpuApi(), + comm_->getXpuApi()->eventRecord(start_event_, stream_), + "Failed to record start event"); +} + +void TorchWorkXCCL::recordEnd() { + XPU_CHECK( + comm_->getXpuApi(), + comm_->getXpuApi()->eventRecord(end_event_, stream_), + "Failed to record end event"); +} + +bool TorchWorkXCCL::isCompleted() { + return state_ == WorkStatus::COMPLETED; +} + +TorchWorkXCCL::WorkStatus TorchWorkXCCL::checkStatus() { + // If already marked as completed, return COMPLETED + if (state_ == WorkStatus::COMPLETED || state_ == WorkStatus::ERROR || + state_ == WorkStatus::TIMEDOUT) { + return state_; + } + + // Step 1: If start_completed_time_ doesn't have a value yet, query the start + // event + if (!start_completed_time_.has_value()) { + xpu_result_t start_status = comm_->getXpuApi()->eventQuery(start_event_); + + if (start_status == XPU_SUCCESS) { + // Start event has completed, store the current time + start_completed_time_ = std::chrono::steady_clock::now(); + state_ = WorkStatus::INPROGRESS; + } else if ( + start_status != XPU_ERROR_NOT_READY && + start_status != XPU_ERROR_UNSUPPORTED) { + // Some other error occurred with the start event + TC_LOG(ERROR) << "XPU error during start event query: " + << comm_->getXpuApi()->getErrorString(start_status) << " (" + << start_status << ")"; + state_ = WorkStatus::ERROR; + } + } + if (state_ == WorkStatus::NOT_STARTED || state_ == WorkStatus::ERROR) { + return state_; + } + + // Step 2: If we get here, start event has completed, so query the end event + xpu_result_t end_status = comm_->getXpuApi()->eventQuery(end_event_); + + if (end_status == XPU_SUCCESS) { + // End event has completed, mark the work as completed + state_ = WorkStatus::COMPLETED; + + // Release the input tensors to keep the lifetime of the tensors short + inputTensors_.clear(); + } else if (end_status == XPU_ERROR_NOT_READY) { + // End event has not completed yet, check for timeout + auto current_time = std::chrono::steady_clock::now(); + 
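+    // Note: the timeout window is measured from the moment the start event
+    // completed (start_completed_time_), not from enqueue time, so time spent
+    // queued behind earlier work on the stream does not count against
+    // timeout_ms_.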
auto elapsed_milliseconds = + std::chrono::duration_cast( + current_time - start_completed_time_.value()); + + // Check if the operation has timed out + if (elapsed_milliseconds > timeout_ms_) { + // Operation has timed out + state_ = WorkStatus::TIMEDOUT; + } + } else if (end_status != XPU_ERROR_UNSUPPORTED) { + // Some other error occurred with the end event + TC_LOG(ERROR) << "XPU error during end event query: " + << comm_->getXpuApi()->getErrorString(end_status) << " (" + << end_status << ")"; + state_ = WorkStatus::ERROR; + } + return state_; +} + +void TorchWorkXCCL::wait() { + // If already completed, return immediately + WorkStatus local_state = state_; + if (local_state == WorkStatus::COMPLETED || + local_state == WorkStatus::ERROR || local_state == WorkStatus::TIMEDOUT) { + return; + } + + tracing_->recordEvent("wait"); + + // Get the current stream using the device from the comm object + xpuStream_t current_stream = + comm_->getXpuApi()->getCurrentXPUStream(comm_->device_.index()); + + // Add a dependency from the work's stream to the current stream + // This makes the current stream wait for the end_event_ recorded on the + // work's stream + XPU_CHECK( + comm_->getXpuApi(), + comm_->getXpuApi()->streamWaitEvent(current_stream, end_event_, 0), + "Failed to make stream wait for event"); +} +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.hpp b/comms/torchcomms/xccl/TorchWorkXCCL.hpp new file mode 100644 index 00000000..0c0e9fbc --- /dev/null +++ b/comms/torchcomms/xccl/TorchWorkXCCL.hpp @@ -0,0 +1,108 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +// #include // @manual=third-party//xpu:xpu-lazy +#include "comms/torchcomms/TorchCommTracing.hpp" +#include "comms/torchcomms/TorchWork.hpp" +#include "comms/torchcomms/device/XpuApi.hpp" + +namespace torch { +namespace comms { + +// Forward declaration +class TorchCommXCCL; +class TorchCommWindowXCCL; + +// Forward declaration for test class +namespace test { +class TorchCommXCCLTest; +} + +class TorchWorkXCCL : public TorchWork { + public: + // Status of a work object + enum class WorkStatus { + NOT_STARTED, // Work has not started yet + INPROGRESS, // Work is still in progress, + COMPLETED, // Work has completed successfully + TIMEDOUT, // Work has timed out + ERROR // Work has encountered an error + }; + + TorchWorkXCCL( + std::shared_ptr comm, + xpuStream_t stream, + std::chrono::milliseconds timeout_ms, + const std::vector& inputTensors, + std::shared_ptr tracing); + ~TorchWorkXCCL() override; + + // Delete copy and move operations + TorchWorkXCCL(const TorchWorkXCCL&) = delete; + TorchWorkXCCL(TorchWorkXCCL&&) = delete; + TorchWorkXCCL& operator=(const TorchWorkXCCL&) = delete; + TorchWorkXCCL& operator=(TorchWorkXCCL&&) = delete; + + // Override virtual functions from TorchWork + bool isCompleted() override; + void wait() override; + + protected: + void recordStart(); + void recordEnd(); + + friend class TorchCommXCCL; + friend class TorchWorkXCCLQueue; + + private: + // Check the status of the work object + WorkStatus checkStatus(); + + std::chrono::milliseconds getTimeout() const { + return timeout_ms_; + } + std::vector inputTensors_; + + std::shared_ptr comm_; + xpuEvent_t start_event_; + xpuEvent_t end_event_; + xpuStream_t stream_; // stream is not owned by this class + + std::chrono::milliseconds timeout_ms_; + + // state machine variables. 
TODO: convert to state machine later + std::atomic state_; + + std::optional start_completed_time_; + std::shared_ptr tracing_; +}; + +class TorchWorkXCCLQueue { + public: + TorchWorkXCCLQueue() = default; + ~TorchWorkXCCLQueue() = default; + + TorchWorkXCCL::WorkStatus garbageCollect(bool isMainThread); + // Finalize function can only be called from the main thread + TorchWorkXCCL::WorkStatus finalize(); + void enqueueWork(std::shared_ptr work, xpuStream_t stream); + + private: + std::unordered_map>> + stream_work_queues_; + std::vector> completed_work_queue_; + std::recursive_mutex work_queues_mutex_; +}; + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp new file mode 100644 index 00000000..d1d0267c --- /dev/null +++ b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp @@ -0,0 +1,100 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp" + +namespace torch { +namespace comms { + +TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::garbageCollect( + bool isMainThread) { + std::lock_guard lock(work_queues_mutex_); + + TorchWorkXCCL::WorkStatus last_status = TorchWorkXCCL::WorkStatus::COMPLETED; + + // Keep popping completed elements until we hit an in-progress element + // or the queue is empty + // Use an iterator to safely remove empty queues while iterating + auto it = stream_work_queues_.begin(); + while (it != stream_work_queues_.end()) { + auto& work_queue = it->second; + + while (!work_queue.empty()) { + // Get the first work object in the queue + auto work = work_queue.front(); + + // Use the checkStatus function to determine the work status + TorchWorkXCCL::WorkStatus status = work->checkStatus(); + last_status = status; + + if (status == TorchWorkXCCL::WorkStatus::COMPLETED) { + // Work is completed, remove it from the work queue + work_queue.pop(); + completed_work_queue_.push_back(work); + // Continue to the next element in the queue + } else if ( + status == TorchWorkXCCL::WorkStatus::TIMEDOUT || + status == TorchWorkXCCL::WorkStatus::ERROR) { + // Return the error status immediately + return status; + } else { + // NOT_STARTED or INPROGRESS - stop processing this queue + break; + } + } + + // If the queue is now empty, remove it from the map + if (work_queue.empty()) { + it = stream_work_queues_.erase(it); + } else { + ++it; + } + } + + if (isMainThread) { + // If we are the main thread, clear the completed work queues + completed_work_queue_.clear(); + } + + return last_status; +} + +TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::finalize() { + // Because this function is typically called after the timeout thread has + // already joined, we might not need to lock here. But doing the lock anyway, + // as defensive programming, just in case someone moves the thread join order + // later. The cost of the lock itself should be small on modern linux systems + // (uncontended locks are typically just an atomic operation). + std::lock_guard lock(work_queues_mutex_); + + // Initialize the status to COMPLETED to cover the case where the queue is + // empty + TorchWorkXCCL::WorkStatus status = TorchWorkXCCL::WorkStatus::COMPLETED; + while (!stream_work_queues_.empty()) { + status = garbageCollect(true); + if (status == TorchWorkXCCL::WorkStatus::ERROR || + status == TorchWorkXCCL::WorkStatus::TIMEDOUT || + status == TorchWorkXCCL::WorkStatus::COMPLETED) { + break; + } + } + + // Clear all work queues & completed work queue. 
+ // + // NOTE: finalize MUST return without holding references to any work object, + // otherwise it may leak object and cause side effects. + stream_work_queues_.clear(); + completed_work_queue_.clear(); + + return status; +} + +void TorchWorkXCCLQueue::enqueueWork( + std::shared_ptr work, + xpuStream_t stream) { + // Add work to stream's queue after events have been recorded + std::lock_guard lock(work_queues_mutex_); + stream_work_queues_[stream].push(work); +} + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/XcclApi.cpp b/comms/torchcomms/xccl/XcclApi.cpp new file mode 100644 index 00000000..96d59bbc --- /dev/null +++ b/comms/torchcomms/xccl/XcclApi.cpp @@ -0,0 +1,197 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. + +#include "comms/torchcomms/xccl/XcclApi.hpp" +#include "comms/torchcomms/TorchCommLogging.hpp" + +namespace torch { +namespace comms { + +// DefaultXcclApi implementation +const char* DefaultXcclApi::getErrorString(onecclResult_t result) { + return onecclGetErrorString(result); +} + +onecclResult_t DefaultXcclApi::setDevice(int device) { + return onecclSetDevice(device); +} + +onecclResult_t DefaultXcclApi::getUniqueId(onecclUniqueId* uniqueId) { + return onecclGetUniqueId(uniqueId); +} + +onecclResult_t DefaultXcclApi::commInitRankConfig( + onecclComm_t* comm, + int nranks, + onecclUniqueId commId, + int rank, + onecclConfig_t* config) { + return onecclCommInitRankConfig(comm, nranks, commId, rank, config); +} + +onecclResult_t DefaultXcclApi::commDestroy(onecclComm_t comm) { + return onecclCommDestroy(comm); +} + +onecclResult_t DefaultXcclApi::commAbort(onecclComm_t comm) { + // return onecclCommAbort(comm); + return onecclNotImplemented; +} + +onecclResult_t DefaultXcclApi::commGetAsyncError( + onecclComm_t comm, + onecclResult_t* asyncError) { + // return onecclCommGetAsyncError(comm); + return onecclNotImplemented; +} + +onecclResult_t DefaultXcclApi::commSplit( + onecclComm_t comm, + int color, + int key, + onecclComm_t* newcomm, + onecclConfig_t* config) { + return onecclCommSplit(comm, color, key, newcomm, config); +} + +onecclResult_t DefaultXcclApi::commRegister( + onecclComm_t comm, + void* buffer, + size_t size, + void** handle) { + // return onecclCommRegister(comm, buffer, size, handle); + return onecclNotImplemented; +} + +onecclResult_t DefaultXcclApi::commDeregister(onecclComm_t comm, void* handle) { + // return onecclCommDeregister(comm, handle); + return onecclNotImplemented; +} + +onecclResult_t DefaultXcclApi::send( + const void* sendbuff, + size_t count, + onecclDataType_t datatype, + int peer, + onecclComm_t comm, + xpuStream_t stream) { + return onecclSend(const_cast(sendbuff), count, datatype, peer, comm, stream); +} + +onecclResult_t DefaultXcclApi::recv( + void* recvbuff, + size_t count, + onecclDataType_t datatype, + int peer, + onecclComm_t comm, + xpuStream_t stream) { + return onecclRecv(recvbuff, count, datatype, peer, comm, stream); +} + +onecclResult_t DefaultXcclApi::broadcast( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + int root, + onecclComm_t comm, + xpuStream_t stream) { + return onecclBroadcast(const_cast(sendbuff), recvbuff, count, datatype, root, comm, stream); +} + +onecclResult_t DefaultXcclApi::bcast( + void* buff, + size_t count, + onecclDataType_t datatype, + int root, + onecclComm_t comm, + xpuStream_t stream) { + return onecclBroadcast(buff, buff, count, datatype, root, comm, stream); +} + +onecclResult_t DefaultXcclApi::allReduce( + 
const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) { + return onecclAllReduce(const_cast(sendbuff), recvbuff, count, datatype, op, comm, stream); +} + +onecclResult_t DefaultXcclApi::reduce( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, + int root, + onecclComm_t comm, + xpuStream_t stream) { + return onecclReduce( + const_cast(sendbuff), recvbuff, count, datatype, op, root, comm, stream); +} + +onecclResult_t DefaultXcclApi::allGather( + const void* sendbuff, + void* recvbuff, + size_t sendcount, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) { + return onecclAllGather(const_cast(sendbuff), recvbuff, sendcount, datatype, comm, stream); +} + +onecclResult_t DefaultXcclApi::reduceScatter( + const void* sendbuff, + void* recvbuff, + size_t recvcount, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) { + return onecclReduceScatter( + const_cast(sendbuff), recvbuff, recvcount, datatype, op, comm, stream); +} + +onecclResult_t DefaultXcclApi::allToAll( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) { + return onecclAllToAll(const_cast(sendbuff), recvbuff, count, datatype, comm, stream); +} + +onecclResult_t DefaultXcclApi::groupStart() { + return onecclGroupStart(); +} + +onecclResult_t DefaultXcclApi::groupEnd() { + return onecclGroupEnd(); +} + +onecclResult_t DefaultXcclApi::commUserRank(const onecclComm_t comm, int* myRank) { + return onecclCommUserRank(comm, myRank); +} + +onecclResult_t DefaultXcclApi::commCount(const onecclComm_t comm, int* count) { + return onecclCommCount(comm, count); +} + +onecclResult_t DefaultXcclApi::redOpCreatePreMulSum( + onecclRedOp_t* op, + void* scalar, + onecclDataType_t datatype, + onecclScalarResidence_t residence, + onecclComm_t comm) { + return onecclRedOpCreatePreMulSum(op, scalar, datatype, residence, comm); +} + +onecclResult_t DefaultXcclApi::redOpDestroy(onecclRedOp_t op, onecclComm_t comm) { + return onecclRedOpDestroy(op, comm); +} + +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/XcclApi.hpp b/comms/torchcomms/xccl/XcclApi.hpp new file mode 100644 index 00000000..a6e0b577 --- /dev/null +++ b/comms/torchcomms/xccl/XcclApi.hpp @@ -0,0 +1,289 @@ +#pragma once + +#include +#include + +#include "comms/torchcomms/device/XpuApi.hpp" + +namespace torch { +namespace comms { + +class XcclApi { + public: + virtual ~XcclApi() = default; + + virtual const char* getErrorString(onecclResult_t result) = 0; + + // Device management + virtual onecclResult_t setDevice(int device) = 0; + + // Unique ID generation + virtual onecclResult_t getUniqueId(onecclUniqueId* uniqueId) = 0; + + // Communicator management + virtual onecclResult_t commInitRankConfig( + onecclComm_t* comm, + int nranks, + onecclUniqueId commId, + int rank, + onecclConfig_t* config) = 0; + + virtual onecclResult_t commDestroy(onecclComm_t comm) = 0; + + virtual onecclResult_t commAbort(onecclComm_t comm) = 0; + + virtual onecclResult_t commGetAsyncError( + onecclComm_t comm, + onecclResult_t* asyncError) = 0; + + virtual onecclResult_t commSplit( + onecclComm_t comm, + int color, + int key, + onecclComm_t* newcomm, + onecclConfig_t* config) = 0; + + // Memory registration + virtual onecclResult_t + commRegister(onecclComm_t comm, void* buffer, size_t size, void** 
handle) = 0; + + virtual onecclResult_t commDeregister(onecclComm_t comm, void* handle) = 0; + + // Point-to-point operations + virtual onecclResult_t send( + const void* sendbuff, + size_t count, + onecclDataType_t datatype, + int peer, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t recv( + void* recvbuff, + size_t count, + onecclDataType_t datatype, + int peer, + onecclComm_t comm, + xpuStream_t stream) = 0; + + // Collective operations + virtual onecclResult_t broadcast( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + int root, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t bcast( + void* buff, + size_t count, + onecclDataType_t datatype, + int root, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t allReduce( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t reduce( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, + int root, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t allGather( + const void* sendbuff, + void* recvbuff, + size_t sendcount, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t reduceScatter( + const void* sendbuff, + void* recvbuff, + size_t recvcount, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t allToAll( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) = 0; + + // Group operations + virtual onecclResult_t groupStart() = 0; + virtual onecclResult_t groupEnd() = 0; + + virtual onecclResult_t commUserRank(const onecclComm_t comm, int* userRank) = 0; + virtual onecclResult_t commCount(const onecclComm_t comm, int* count) = 0; + + virtual onecclResult_t redOpCreatePreMulSum( + onecclRedOp_t* op, + void* scalar, + onecclDataType_t datatype, + onecclScalarResidence_t residence, + onecclComm_t comm) = 0; + virtual onecclResult_t redOpDestroy(onecclRedOp_t op, onecclComm_t comm) = 0; +}; + +/** + * Default implementation that calls the underlying XCCL APIs directly. 
+ */ +class DefaultXcclApi : public XcclApi { + public: + ~DefaultXcclApi() override = default; + + // Error handling + const char* getErrorString(onecclResult_t result) override; + + // Device management + onecclResult_t setDevice(int device) override; + + // Unique ID generation + onecclResult_t getUniqueId(onecclUniqueId* uniqueId) override; + + // Communicator management + onecclResult_t commInitRankConfig( + onecclComm_t* comm, + int nranks, + onecclUniqueId commId, + int rank, + onecclConfig_t* config) override; + + onecclResult_t commDestroy(onecclComm_t comm) override; + + onecclResult_t commAbort(onecclComm_t comm) override; + + onecclResult_t commGetAsyncError(onecclComm_t comm, onecclResult_t* asyncError) + override; + + onecclResult_t commSplit( + onecclComm_t comm, + int color, + int key, + onecclComm_t* newcomm, + onecclConfig_t* config) override; + + onecclResult_t commRegister( + onecclComm_t comm, + void* buffer, + size_t size, + void** handle) override; + + onecclResult_t commDeregister(onecclComm_t comm, void* handle) override; + + // Point-to-point operations + onecclResult_t send( + const void* sendbuff, + size_t count, + onecclDataType_t datatype, + int peer, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t recv( + void* recvbuff, + size_t count, + onecclDataType_t datatype, + int peer, + onecclComm_t comm, + xpuStream_t stream) override; + + // Collective operations + onecclResult_t broadcast( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + int root, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t bcast( + void* buff, + size_t count, + onecclDataType_t datatype, + int root, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t allReduce( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t reduce( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, + int root, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t allGather( + const void* sendbuff, + void* recvbuff, + size_t sendcount, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t reduceScatter( + const void* sendbuff, + void* recvbuff, + size_t recvcount, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t allToAll( + const void* sendbuff, + void* recvbuff, + size_t count, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) override; + + // Group operations + onecclResult_t groupStart() override; + onecclResult_t groupEnd() override; + + onecclResult_t commUserRank(const onecclComm_t comm, int* userRank) override; + onecclResult_t commCount(const onecclComm_t comm, int* count) override; + + onecclResult_t redOpCreatePreMulSum( + onecclRedOp_t* op, + void* scalar, + onecclDataType_t datatype, + onecclScalarResidence_t residence, + onecclComm_t comm) override; + onecclResult_t redOpDestroy(onecclRedOp_t op, onecclComm_t comm) override; +}; + +} // namespace comms +} // namespace torch diff --git a/setup.py b/setup.py index dc35dda9..648fd3d8 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def flag_str(val: bool): USE_NCCLX = flag_enabled("USE_NCCLX", True) USE_GLOO = flag_enabled("USE_GLOO", True) USE_RCCL = flag_enabled("USE_RCCL", False) +USE_XCCL = 
flag_enabled("USE_XCCL", False) requirement_path = os.path.join(ROOT, "requirements.txt") try: @@ -105,6 +106,7 @@ def build_cmake(self, ext): f"-DUSE_NCCLX={flag_str(USE_NCCLX)}", f"-DUSE_GLOO={flag_str(USE_GLOO)}", f"-DUSE_RCCL={flag_str(USE_RCCL)}", + f"-DUSE_XCCL={flag_str(USE_XCCL)}", ] build_args = ["--", "-j"] @@ -145,6 +147,10 @@ def build_cmake(self, ext): ext_modules += [ CMakeExtension("torchcomms._comms_rccl"), ] +if USE_XCCL: + ext_modules += [ + CMakeExtension("torchcomms._comms_xccl"), + ] setup( name="torchcomms", @@ -157,6 +163,7 @@ def build_cmake(self, ext): "ncclx = torchcomms._comms_ncclx", "gloo = torchcomms._comms_gloo", "rccl = torchcomms._comms_rccl", + "xccl = torchcomms._comms_xccl", ] }, ext_modules=ext_modules, From 9d142e417fd3465b190b8b78047e6c7e8e6e6d0b Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 3 Nov 2025 15:19:00 +0800 Subject: [PATCH 02/10] comments and format --- comms/torchcomms/device/XpuApi.cpp | 44 +- comms/torchcomms/device/XpuApi.hpp | 7 - comms/torchcomms/xccl/CMakeLists.txt | 2 - comms/torchcomms/xccl/TorchCommXCCL.cpp | 752 ++++++------------ comms/torchcomms/xccl/TorchCommXCCL.hpp | 283 +++---- .../xccl/TorchCommXCCLBootstrap.cpp | 171 ++-- .../xccl/TorchCommXCCLBootstrap.hpp | 53 +- comms/torchcomms/xccl/TorchCommXCCLPy.cpp | 2 - comms/torchcomms/xccl/TorchCommXCCLUtils.cpp | 377 +++------ comms/torchcomms/xccl/TorchWorkXCCL.cpp | 52 +- comms/torchcomms/xccl/TorchWorkXCCL.hpp | 43 +- comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp | 18 +- comms/torchcomms/xccl/XcclApi.cpp | 206 ++--- comms/torchcomms/xccl/XcclApi.hpp | 328 +++----- 14 files changed, 814 insertions(+), 1524 deletions(-) diff --git a/comms/torchcomms/device/XpuApi.cpp b/comms/torchcomms/device/XpuApi.cpp index d6d594cd..18f80143 100644 --- a/comms/torchcomms/device/XpuApi.cpp +++ b/comms/torchcomms/device/XpuApi.cpp @@ -1,5 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
- #include "comms/torchcomms/device/XpuApi.hpp" #include #include @@ -10,10 +8,6 @@ namespace torch { namespace comms { -// ============================================================================ -// Device Management Implementation -// ============================================================================ - xpu_result_t DefaultXpuApi::setDevice(int device) { try { ::c10::xpu::set_device(device); @@ -29,20 +23,21 @@ xpu_result_t DefaultXpuApi::getDeviceProperties(xpuDeviceProp* prop, int device) } try { - ::c10::xpu::DeviceProp device_prop; - ::c10::xpu::get_device_properties(&device_prop, device); + sycl::device sycl_device = ::c10::xpu::get_raw_device(device); - // Map c10::xpu::DeviceProp to xpuDeviceProp - strncpy(prop->name, device_prop.name.c_str(), 255); + // Get device name + std::string device_name = sycl_device.get_info(); + strncpy(prop->name, device_name.c_str(), 255); prop->name[255] = '\0'; - prop->totalGlobalMem = device_prop.global_mem_size; - // Extract major/minor version from device_id - prop->major = (device_prop.device_id >> 16) & 0xFFFF; - prop->minor = device_prop.device_id & 0xFFFF; + // Get memory info + prop->totalGlobalMem = sycl_device.get_info(); - // Get SYCL device for additional properties - sycl::device& sycl_device = ::c10::xpu::get_raw_device(device); + // Set version info (XPU doesn't have major/minor version like CUDA) + prop->major = 1; + prop->minor = 0; + + // Get compute capabilities auto max_work_group_size = sycl_device.get_info(); auto max_work_item_sizes = sycl_device.get_info>(); auto max_compute_units = sycl_device.get_info(); @@ -95,10 +90,6 @@ xpu_result_t DefaultXpuApi::getDeviceCount(int* count) { } } -// ============================================================================ -// Stream Management Implementation -// ============================================================================ - xpu_result_t DefaultXpuApi::streamCreateWithPriority( xpuStream_t& stream, unsigned int flags, @@ -170,10 +161,6 @@ xpu_result_t DefaultXpuApi::streamGetCaptureInfo( return XPU_SUCCESS; } -// ============================================================================ -// Memory Management Implementation -// ============================================================================ - xpu_result_t DefaultXpuApi::malloc(void** devPtr, size_t size) { if (!devPtr) { return XPU_ERROR_INVALID_VALUE; @@ -237,9 +224,6 @@ xpu_result_t DefaultXpuApi::memcpyAsync( } } -// ============================================================================ -// Event Management Implementation -// ============================================================================ xpu_result_t DefaultXpuApi::eventCreate(xpuEvent_t& event) { try { @@ -285,10 +269,7 @@ xpu_result_t DefaultXpuApi::eventQuery(const xpuEvent_t& event) { } } -// ============================================================================ // Graph Operations (Unsupported) -// ============================================================================ - xpu_result_t DefaultXpuApi::userObjectCreate( xpuUserObject_t* object_out, void* ptr, @@ -334,10 +315,7 @@ xpu_result_t DefaultXpuApi::streamGetCaptureInfo_v2( return XPU_SUCCESS; } -// ============================================================================ // Error Handling -// ============================================================================ - const char* DefaultXpuApi::getErrorString(xpu_result_t error) { switch (error) { case XPU_SUCCESS: diff --git a/comms/torchcomms/device/XpuApi.hpp 
b/comms/torchcomms/device/XpuApi.hpp index f6dee9c7..1cc14e8b 100644 --- a/comms/torchcomms/device/XpuApi.hpp +++ b/comms/torchcomms/device/XpuApi.hpp @@ -1,5 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. - #pragma once #include @@ -11,11 +9,9 @@ namespace torch { namespace comms { -// Use PyTorch XPU types directly as value types using xpuStream_t = ::c10::xpu::XPUStream; using xpuEvent_t = ::at::xpu::XPUEvent; -// Device properties structure - mapped from SYCL struct xpuDeviceProp { char name[256]; size_t totalGlobalMem; @@ -135,9 +131,6 @@ class XpuApi { virtual const char* getErrorString(xpu_result_t error) = 0; }; -/** - * Default implementation that uses SYCL and PyTorch XPU APIs. - */ class DefaultXpuApi : public XpuApi { public: ~DefaultXpuApi() override = default; diff --git a/comms/torchcomms/xccl/CMakeLists.txt b/comms/torchcomms/xccl/CMakeLists.txt index dc69015f..88eef1d4 100644 --- a/comms/torchcomms/xccl/CMakeLists.txt +++ b/comms/torchcomms/xccl/CMakeLists.txt @@ -1,8 +1,6 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. # Extension: torchcomms._comms_xccl file(GLOB TORCHCOMMS_XCCL_SOURCES "comms/torchcomms/xccl/*.cpp") file(GLOB TORCHCOMMS_XPU_API_SOURCE "comms/torchcomms/device/XpuApi.cpp") -# find_package(XPU) set(XCCL_INCLUDE "$ENV{CCL_ROOT}/include") set(XCCL_SHARED_LIB "$ENV{CCL_ROOT}/lib/libccl.so.2") diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp index 90d375e4..5fb91e49 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp @@ -1,49 +1,44 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. - #include "comms/torchcomms/xccl/TorchCommXCCL.hpp" +#include "comms/torchcomms/TorchCommFactory.hpp" +#include "comms/torchcomms/TorchCommLogging.hpp" +#include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp" #include #include #include #include -#include "comms/torchcomms/TorchCommFactory.hpp" -#include "comms/torchcomms/TorchCommLogging.hpp" -#include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp" -// #include "xccl.h" // @manual namespace torch { namespace comms { -onecclResult_t XCCLException::getResult() const { - return result_; -} +onecclResult_t XCCLException::getResult() const { return result_; } TorchCommXCCL::TorchCommXCCL() - : xccl_comm_{nullptr}, - device_(at::kXPU), - init_state_(InitializationState::UNINITIALIZED), - shutdown_(false) {} + : xccl_comm_{nullptr}, device_(at::kXPU), + init_state_(InitializationState::UNINITIALIZED), shutdown_(false) {} TorchCommXCCL::TorchCommXCCL(const onecclComm_t xccl_comm) - : xccl_comm_(xccl_comm), - device_(at::kXPU), - init_state_(InitializationState::UNINITIALIZED), - shutdown_(false) {} + : xccl_comm_(xccl_comm), device_(at::kXPU), + init_state_(InitializationState::UNINITIALIZED), shutdown_(false) {} TorchCommXCCL::~TorchCommXCCL() { if (init_state_ == InitializationState::INITIALIZED) { TC_LOG(ERROR) << "TorchCommXCCL was not finalized before destruction"; + + // If finalize was not called, we need to clean up the timeout thread + if (timeout_thread_.joinable()) { + shutdown_.store(true); + timeout_thread_.join(); + } } - // We need to dteach the memory hook in case finalize is not called, + // We need to detach the memory hook in case finalize is not called, // so that we don't encounter a memory corruption. 
  detachMemoryHook();
}

-void TorchCommXCCL::init(
-    at::Device device,
-    const std::string& name,
-    const CommOptions& options) {
+void TorchCommXCCL::init(at::Device device, const std::string &name,
+                         const CommOptions &options) {
   // Initialize private members
   device_ = device;
   name_ = name;
@@ -68,48 +63,23 @@ void TorchCommXCCL::init(
   }
 
   if (device_.index() == -1 || xccl_comm_ == nullptr) {
-    TC_LOG(INFO) << "[TC] Creating bootstrap...";
     auto bootstrap = new TorchCommXCCLBootstrap(
         options_.store, device_, xccl_api_, xpu_api_, options_.timeout);
     device_ = bootstrap->getDevice();
 
     if (xccl_comm_ == nullptr) {
-      TC_LOG(INFO) << "[TC] Creating XCCL communicator...";
       xccl_comm_ = bootstrap->createXcclComm(name, options);
-      TC_LOG(INFO) << "[TC] XCCL communicator created";
     }
     delete bootstrap;
-    TC_LOG(INFO) << "[TC] Bootstrap deleted";
-  }
-
-  // Set XPU device and verify it's accessible
-  TC_LOG(INFO) << "[TC] Setting XPU device " << device_.index();
-  XPU_CHECK(
-      xpu_api_,
-      xpu_api_->setDevice(device_.index()),
-      "Failed to set XPU device to " + std::to_string(device_.index()));
-
-  // Verify device properties and memory availability
-  TC_LOG(INFO) << "[TC] Getting device properties...";
-  xpuDeviceProp device_prop = {};
-  XPU_CHECK(
-      xpu_api_,
-      xpu_api_->getDeviceProperties(&device_prop, device_.index()),
-      "Failed to get device properties for device " +
-          std::to_string(device_.index()));
-
-  // Check available memory
-  TC_LOG(INFO) << "[TC] Getting memory info...";
-  size_t free_memory, total_memory;
-  XPU_CHECK(
-      xpu_api_,
-      xpu_api_->memGetInfo(&free_memory, &total_memory),
-      "Failed to get memory info for device " +
-          std::to_string(device_.index()));
+  }
+
+  // Set XPU device and verify it's accessible
+  XPU_CHECK(xpu_api_, xpu_api_->setDevice(device_.index()),
+            "Failed to set XPU device to " + std::to_string(device_.index()));
 
   // Read hints and store them
-  for (auto const& [key, val] : options_.hints) {
+  for (auto const &[key, val] : options_.hints) {
     if (key.starts_with("torchcomm::xccl::")) {
       if (key == "torchcomm::xccl::high_priority_stream") {
         high_priority_stream_ = string_to_bool(val);
@@ -130,37 +100,24 @@ void TorchCommXCCL::init(
   }
 
   // Initialize internal stream
-  // Get a temporary stream first (to have a valid object), then replace it
-  TC_LOG(INFO) << "[TC] Creating internal XPU stream with priority " << stream_priority;
   xpuStream_t temp_stream = xpu_api_->getCurrentXPUStream(device_.index());
-  XPU_CHECK(
-      xpu_api_,
-      xpu_api_->streamCreateWithPriority(
-          temp_stream, /*flags=*/0, stream_priority),
-      "Failed to create internal XPU stream on device " +
-          std::to_string(device_.index()));
+  XPU_CHECK(xpu_api_,
+            xpu_api_->streamCreateWithPriority(temp_stream, /*flags=*/0,
+                                               stream_priority),
+            "Failed to create internal XPU stream on device " +
+                std::to_string(device_.index()));
   internal_stream_ = std::move(temp_stream);
-  TC_LOG(INFO) << "[TC] Internal stream created";
 
   // Create dependency event for stream synchronization
-  TC_LOG(INFO) << "[TC] Creating dependency event";
   xpuEvent_t temp_event(/*enable_timing=*/false);
-  XPU_CHECK(
-      xpu_api_,
-      xpu_api_->eventCreateWithFlags(
-          temp_event, /*flags=*/0),
-      "Failed to create dependency event on device " +
-          std::to_string(device_.index()));
-  TC_LOG(INFO) << "[TC] Dependency event created";
+  XPU_CHECK(xpu_api_, xpu_api_->eventCreateWithFlags(temp_event, /*flags=*/0),
+            "Failed to create dependency event on device " +
+                std::to_string(device_.index()));
   dependency_event_ = std::move(temp_event);
 
   // Allocate XPU
buffer for barrier operations - TC_LOG(INFO) << "[TC] Allocating barrier buffer"; - XPU_CHECK( - xpu_api_, - xpu_api_->malloc(&barrier_buffer_, sizeof(float)), - "Failed to allocate barrier buffer"); - TC_LOG(INFO) << "[TC] Barrier buffer allocated"; + XPU_CHECK(xpu_api_, xpu_api_->malloc(&barrier_buffer_, sizeof(float)), + "Failed to allocate barrier buffer"); if (options_.hints.contains("torchcomm::xccl::max_event_pool_size")) { max_event_pool_size_ = @@ -177,35 +134,27 @@ void TorchCommXCCL::init( options_.store.reset(); } - TC_LOG(INFO) << "[TC] Getting XCCL rank"; onecclResult_t xcclErr; xcclErr = xccl_api_->commUserRank(xccl_comm_, &rank_); if (xcclErr != onecclSuccess) { throw std::runtime_error("XCCL User Rank failed"); } - TC_LOG(INFO) << "[TC] Rank: " << rank_; tryTorchCommLoggingInit("torchcomm"); - TC_LOG(INFO) << "[TC] Getting XCCL comm size"; xcclErr = xccl_api_->commCount(xccl_comm_, &comm_size_); if (xcclErr != onecclSuccess) { throw std::runtime_error("XCCL Count failed"); } - TC_LOG(INFO) << "[TC] Comm size: " << comm_size_; - TC_LOG(INFO) << "[TC] Creating tracing object"; tracing_ = std::make_shared(name, comm_size_, rank_); tracing_->recordEvent("init"); // Start timeout watchdog thread - TC_LOG(INFO) << "[TC] Starting timeout watchdog thread"; timeout_thread_ = std::thread(&TorchCommXCCL::timeoutWatchdog, this); // Register comm with CachingAllocator - TC_LOG(INFO) << "[TC] Attaching memory hook"; attachMemoryHook(); - TC_LOG(INFO) << "[TC] TorchCommXCCL initialization completed"; } void TorchCommXCCL::finalize() { @@ -259,35 +208,29 @@ void TorchCommXCCL::finalize() { while (!event_pool_.empty()) { xpuEvent_t event = std::move(event_pool_.front()); event_pool_.pop(); - XPU_CHECK( - xpu_api_, xpu_api_->eventDestroy(event), "Failed to destroy event"); + XPU_CHECK(xpu_api_, xpu_api_->eventDestroy(event), + "Failed to destroy event"); } } // Free barrier buffer. 
TODO: handle errors on xpu free and stream destroy if (barrier_buffer_) { - XPU_CHECK( - xpu_api_, - xpu_api_->free(barrier_buffer_), - "Failed to free barrier buffer"); + XPU_CHECK(xpu_api_, xpu_api_->free(barrier_buffer_), + "Failed to free barrier buffer"); barrier_buffer_ = nullptr; } // Destroy dependency event if (dependency_event_.has_value()) { - XPU_CHECK( - xpu_api_, - xpu_api_->eventDestroy(dependency_event_.value()), - "Failed to destroy dependency event"); + XPU_CHECK(xpu_api_, xpu_api_->eventDestroy(dependency_event_.value()), + "Failed to destroy dependency event"); dependency_event_.reset(); } // Destroy internal stream if (internal_stream_.has_value()) { - XPU_CHECK( - xpu_api_, - xpu_api_->streamDestroy(internal_stream_.value()), - "Failed to destroy internal stream"); + XPU_CHECK(xpu_api_, xpu_api_->streamDestroy(internal_stream_.value()), + "Failed to destroy internal stream"); internal_stream_.reset(); } @@ -335,17 +278,13 @@ int TorchCommXCCL::getSize() const { return comm_size; } -std::string_view TorchCommXCCL::getBackendName() const { - return kBackendName; -} +std::string_view TorchCommXCCL::getBackendName() const { return kBackendName; } -std::string_view TorchCommXCCL::getCommName() const { - return name_; -} +std::string_view TorchCommXCCL::getCommName() const { return name_; } -static inline std::chrono::milliseconds getOperationTimeout( - std::chrono::milliseconds timeout, - std::chrono::milliseconds default_timeout) { +static inline std::chrono::milliseconds +getOperationTimeout(std::chrono::milliseconds timeout, + std::chrono::milliseconds default_timeout) { // If timeout is kNoTimeout (0ms), use the default timeout from options if (timeout == kNoTimeout) { return default_timeout; @@ -354,11 +293,9 @@ static inline std::chrono::milliseconds getOperationTimeout( } // Point-to-Point Operations -std::shared_ptr TorchCommXCCL::send( - const at::Tensor& tensor, - int dst, - bool async_op, - const SendOptions& options) { +std::shared_ptr TorchCommXCCL::send(const at::Tensor &tensor, + int dst, bool async_op, + const SendOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(tensor); @@ -369,35 +306,26 @@ std::shared_ptr TorchCommXCCL::send( auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - // Record start event before XCCL operation work->recordStart(); - onecclResult_t result = xccl_api_->send( - tensor.data_ptr(), - tensor.numel(), - getXcclDataType(tensor), - dst, - xccl_comm_, - stream); + onecclResult_t result = + xccl_api_->send(tensor.data_ptr(), tensor.numel(), + getXcclDataType(tensor), dst, xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL Send failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::recv( - at::Tensor& tensor, - int src, - bool async_op, - const RecvOptions& options) { +std::shared_ptr TorchCommXCCL::recv(at::Tensor &tensor, int src, + bool async_op, + const RecvOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(tensor); @@ -408,35 +336,27 @@ std::shared_ptr TorchCommXCCL::recv( auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {}); - // Record start event before XCCL operation work->recordStart(); - onecclResult_t result = xccl_api_->recv( - tensor.data_ptr(), - 
tensor.numel(), - getXcclDataType(tensor), - src, - xccl_comm_, - stream); + onecclResult_t result = + xccl_api_->recv(tensor.data_ptr(), tensor.numel(), + getXcclDataType(tensor), src, xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL Recv failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } // Batch P2P Operations -std::shared_ptr TorchCommXCCL::batch_op_issue( - const std::vector& ops, - bool async_op, - const BatchP2POptions& options) { +std::shared_ptr +TorchCommXCCL::batch_op_issue(const std::vector &ops, + bool async_op, const BatchP2POptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); @@ -448,7 +368,7 @@ std::shared_ptr TorchCommXCCL::batch_op_issue( std::vector input_tensors; std::vector output_tensors; - for (const auto& op : ops) { + for (const auto &op : ops) { if (op.type == BatchSendRecv::P2POp::OpType::SEND) { at::Tensor tensor = op.tensor; ensureTensorContiguous(tensor); @@ -462,16 +382,14 @@ std::shared_ptr TorchCommXCCL::batch_op_issue( } } - tracing_->recordEventWithInputOutput( - "batch_op_issue", rank_, input_tensors, output_tensors); + tracing_->recordEventWithInputOutput("batch_op_issue", rank_, input_tensors, + output_tensors); xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, - getOperationTimeout(options.timeout, options_.timeout), - input_tensors); + auto work = + createWork(stream, getOperationTimeout(options.timeout, options_.timeout), + input_tensors); - // Record start event before XCCL operations work->recordStart(); // Start XCCL group for batched operations @@ -481,59 +399,46 @@ std::shared_ptr TorchCommXCCL::batch_op_issue( } // Issue each operation individually - for (const auto& op : ops) { + for (const auto &op : ops) { if (op.type == BatchSendRecv::P2POp::OpType::SEND) { - result = xccl_api_->send( - op.tensor.data_ptr(), - op.tensor.numel(), - getXcclDataType(op.tensor), - op.peer, - xccl_comm_, - stream); + result = xccl_api_->send(op.tensor.data_ptr(), op.tensor.numel(), + getXcclDataType(op.tensor), op.peer, xccl_comm_, + stream); if (result != onecclSuccess) { xccl_api_->groupEnd(); // Clean up group on error - throw XCCLException( - *xccl_api_, "XCCL Send failed in batch operation", result); + throw XCCLException(*xccl_api_, "XCCL Send failed in batch operation", + result); } } else if (op.type == BatchSendRecv::P2POp::OpType::RECV) { - result = xccl_api_->recv( - op.tensor.data_ptr(), - op.tensor.numel(), - getXcclDataType(op.tensor), - op.peer, - xccl_comm_, - stream); + result = xccl_api_->recv(op.tensor.data_ptr(), op.tensor.numel(), + getXcclDataType(op.tensor), op.peer, xccl_comm_, + stream); if (result != onecclSuccess) { xccl_api_->groupEnd(); // Clean up group on error - throw XCCLException( - *xccl_api_, "XCCL Recv failed in batch operation", result); + throw XCCLException(*xccl_api_, "XCCL Recv failed in batch operation", + result); } } } - // End XCCL group result = xccl_api_->groupEnd(); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL GroupEnd failed", result); } - // Record end event after XCCL operations work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } // Collective Operations -std::shared_ptr TorchCommXCCL::broadcast( - at::Tensor& tensor, - int root, - bool async_op, - const BroadcastOptions& options) { +std::shared_ptr 
+TorchCommXCCL::broadcast(at::Tensor &tensor, int root, bool async_op, + const BroadcastOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(tensor); @@ -545,35 +450,26 @@ std::shared_ptr TorchCommXCCL::broadcast( auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - // Record start event before XCCL operation work->recordStart(); - onecclResult_t result = xccl_api_->bcast( - tensor.data_ptr(), - tensor.numel(), - getXcclDataType(tensor), - root, - xccl_comm_, - stream); + onecclResult_t result = + xccl_api_->bcast(tensor.data_ptr(), tensor.numel(), + getXcclDataType(tensor), root, xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL Broadcast failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::all_reduce( - at::Tensor& tensor, - ReduceOp op, - bool async_op, - const AllReduceOptions& options) { +std::shared_ptr +TorchCommXCCL::all_reduce(at::Tensor &tensor, ReduceOp op, bool async_op, + const AllReduceOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(tensor); @@ -584,38 +480,30 @@ std::shared_ptr TorchCommXCCL::all_reduce( auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - // Record start event before XCCL operation work->recordStart(); const auto dataType = getXcclDataType(tensor); onecclResult_t result = xccl_api_->allReduce( tensor.data_ptr(), tensor.data_ptr(), // In-place operation - tensor.numel(), - dataType, - getXcclReduceOp(op, xccl_comm_, dataType), - xccl_comm_, - stream); + tensor.numel(), dataType, getXcclReduceOp(op, xccl_comm_, dataType), + xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL AllReduce failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::reduce( - const at::Tensor& tensor, - int root, - ReduceOp op, - bool async_op, - const ReduceOptions& options) { +std::shared_ptr TorchCommXCCL::reduce(const at::Tensor &tensor, + int root, ReduceOp op, + bool async_op, + const ReduceOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(tensor); @@ -630,38 +518,30 @@ std::shared_ptr TorchCommXCCL::reduce( auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - // Record start event before XCCL operation work->recordStart(); const auto dataType = getXcclDataType(tensor); onecclResult_t result = xccl_api_->reduce( tensor.data_ptr(), - rank_ == root ? 
tensor.data_ptr() : nullptr, - tensor.numel(), - dataType, - getXcclReduceOp(op, xccl_comm_, dataType), - root, - xccl_comm_, - stream); + tensor.data_ptr(), // Use same buffer for all ranks + tensor.numel(), dataType, getXcclReduceOp(op, xccl_comm_, dataType), root, + xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL Reduce failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::all_gather( - const std::vector& tensor_list, - const at::Tensor& tensor, - bool async_op, - const AllGatherOptions& options) { +std::shared_ptr +TorchCommXCCL::all_gather(const std::vector &tensor_list, + const at::Tensor &tensor, bool async_op, + const AllGatherOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); if (tensor_list.size() != static_cast(comm_size_)) { @@ -669,11 +549,9 @@ std::shared_ptr TorchCommXCCL::all_gather( "tensor_list size must equal comm_size for all_gather"); } - // Ensure input tensor is contiguous ensureTensorContiguous(tensor); - // Check that all output tensors are contiguous and have correct size - for (const auto& t : tensor_list) { + for (const auto &t : tensor_list) { ensureTensorContiguous(t); if (t.numel() != tensor.numel()) { throw std::runtime_error( @@ -681,8 +559,8 @@ std::shared_ptr TorchCommXCCL::all_gather( } } - tracing_->recordEventWithInputOutput( - "all_gather", rank_, tensor_list, {tensor}); + tracing_->recordEventWithInputOutput("all_gather", rank_, tensor_list, + {tensor}); xpuStream_t stream = getOperationStream(async_op); auto work = createWork( @@ -690,47 +568,39 @@ std::shared_ptr TorchCommXCCL::all_gather( work->recordStart(); - // Use multiple broadcast operations for all_gather xccl_api_->groupStart(); for (int i = 0; i < comm_size_; ++i) { - xccl_api_->broadcast( - tensor.data_ptr(), - tensor_list[i].data_ptr(), - tensor.numel(), - getXcclDataType(tensor_list[i]), - i, - xccl_comm_, - stream); + xccl_api_->broadcast(tensor.data_ptr(), tensor_list[i].data_ptr(), + tensor.numel(), getXcclDataType(tensor_list[i]), i, + xccl_comm_, stream); } xccl_api_->groupEnd(); work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::all_gather_single( - at::Tensor& output, - const at::Tensor& input, - bool async_op, - const AllGatherSingleOptions& options) { +std::shared_ptr +TorchCommXCCL::all_gather_single(at::Tensor &output, const at::Tensor &input, + bool async_op, + const AllGatherSingleOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(output); ensureTensorContiguous(input); if (output.numel() != input.numel() * comm_size_) { - throw std::runtime_error( - "Output tensor size must be input_size * comm_size for all_gather_single"); + throw std::runtime_error("Output tensor size must be input_size * " + "comm_size for all_gather_single"); } - tracing_->recordEventWithInputOutput( - "all_gather_single", rank_, {input}, {output}); + tracing_->recordEventWithInputOutput("all_gather_single", rank_, {input}, + {output}); xpuStream_t stream = getOperationStream(async_op); auto work = createWork( @@ -738,13 +608,9 @@ std::shared_ptr TorchCommXCCL::all_gather_single( work->recordStart(); - onecclResult_t result = xccl_api_->allGather( - input.data_ptr(), - output.data_ptr(), - input.numel(), - getXcclDataType(input), - xccl_comm_, - 
stream); + onecclResult_t result = + xccl_api_->allGather(input.data_ptr(), output.data_ptr(), input.numel(), + getXcclDataType(input), xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL AllGather failed", result); @@ -752,18 +618,14 @@ std::shared_ptr TorchCommXCCL::all_gather_single( work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } std::shared_ptr TorchCommXCCL::reduce_scatter( - at::Tensor& output, - const std::vector& input_list, - ReduceOp op, - bool async_op, - const ReduceScatterOptions& options) { + at::Tensor &output, const std::vector &input_list, ReduceOp op, + bool async_op, const ReduceScatterOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(output); @@ -774,7 +636,7 @@ std::shared_ptr TorchCommXCCL::reduce_scatter( } // Check that all input tensors are contiguous and have correct size - for (const auto& t : input_list) { + for (const auto &t : input_list) { ensureTensorContiguous(t); if (t.numel() != output.numel()) { throw std::runtime_error( @@ -782,14 +644,13 @@ std::shared_ptr TorchCommXCCL::reduce_scatter( } } - tracing_->recordEventWithInputOutput( - "reduce_scatter", rank_, input_list, {output}); + tracing_->recordEventWithInputOutput("reduce_scatter", rank_, input_list, + {output}); xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, - getOperationTimeout(options.timeout, options_.timeout), - input_list); + auto work = + createWork(stream, getOperationTimeout(options.timeout, options_.timeout), + input_list); work->recordStart(); @@ -798,95 +659,64 @@ std::shared_ptr TorchCommXCCL::reduce_scatter( for (int i = 0; i < comm_size_; ++i) { const auto dataType = getXcclDataType(input_list[i]); - if (i == rank_) { - // This rank receives the reduced result - xccl_api_->reduce( - input_list[i].data_ptr(), - output.data_ptr(), - output.numel(), - dataType, - getXcclReduceOp(op, xccl_comm_, dataType), - i, - xccl_comm_, - stream); - } else { - // Other ranks contribute to the reduction - xccl_api_->reduce( - input_list[i].data_ptr(), - nullptr, // Non-root ranks don't receive - input_list[i].numel(), - dataType, - getXcclReduceOp(op, xccl_comm_, dataType), - i, - xccl_comm_, - stream); - } + xccl_api_->reduce(input_list[i].data_ptr(), + i == rank_ ? output.data_ptr() : input_list[i].data_ptr(), + i == rank_ ? 
output.numel() : input_list[i].numel(), + dataType, getXcclReduceOp(op, xccl_comm_, dataType), i, + xccl_comm_, stream); } xccl_api_->groupEnd(); work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } std::shared_ptr TorchCommXCCL::reduce_scatter_single( - at::Tensor& output, - const at::Tensor& input, - ReduceOp op, - bool async_op, - const ReduceScatterSingleOptions& options) { + at::Tensor &output, const at::Tensor &input, ReduceOp op, bool async_op, + const ReduceScatterSingleOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(output); ensureTensorContiguous(input); if (input.numel() != output.numel() * comm_size_) { - throw std::runtime_error( - "Input tensor size must be output_size * comm_size for reduce_scatter_single"); + throw std::runtime_error("Input tensor size must be output_size * " + "comm_size for reduce_scatter_single"); } - tracing_->recordEventWithInputOutput( - "reduce_scatter_single", rank_, {input}, {output}); + tracing_->recordEventWithInputOutput("reduce_scatter_single", rank_, {input}, + {output}); xpuStream_t stream = getOperationStream(async_op); auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {input}); - // Record start event before XCCL operation work->recordStart(); const auto dataType = getXcclDataType(input); onecclResult_t result = xccl_api_->reduceScatter( - input.data_ptr(), - output.data_ptr(), - output.numel(), - dataType, - getXcclReduceOp(op, xccl_comm_, dataType), - xccl_comm_, - stream); + input.data_ptr(), output.data_ptr(), output.numel(), dataType, + getXcclReduceOp(op, xccl_comm_, dataType), xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL ReduceScatter failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::all_to_all_single( - at::Tensor& output, - const at::Tensor& input, - bool async_op, - const AllToAllSingleOptions& options) { +std::shared_ptr +TorchCommXCCL::all_to_all_single(at::Tensor &output, const at::Tensor &input, + bool async_op, + const AllToAllSingleOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(output); @@ -902,46 +732,37 @@ std::shared_ptr TorchCommXCCL::all_to_all_single( "Tensor size must be divisible by comm_size for all_to_all_single"); } - tracing_->recordEventWithInputOutput( - "all_to_all_single", rank_, {input}, {output}); + tracing_->recordEventWithInputOutput("all_to_all_single", rank_, {input}, + {output}); xpuStream_t stream = getOperationStream(async_op); auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {input}); - // Record start event before XCCL operation work->recordStart(); size_t chunk_size = input.numel() / comm_size_; const auto data_type = getXcclDataType(input); - onecclResult_t result = xccl_api_->allToAll( - input.data_ptr(), - output.data_ptr(), - chunk_size, - data_type, - xccl_comm_, - stream); + onecclResult_t result = + xccl_api_->allToAll(input.data_ptr(), output.data_ptr(), chunk_size, + data_type, xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL AllToAll failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } 
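// ---------------------------------------------------------------------------
// Reviewer sketch (illustrative only, not part of this patch): caller-side
// semantics of all_to_all_single above. Both tensors are treated as comm_size
// equal chunks; chunk i of `input` is sent to rank i, and the chunk received
// from rank i lands at offset i of `output`. Assuming a constructed `comm`
// backend (hypothetical variable) and the usual TorchWork::wait() semantics:
//
//   const int world = comm->getSize();
//   at::Tensor input = at::arange(
//       world * 2,
//       at::TensorOptions().dtype(at::kFloat).device(comm->getDevice()));
//   at::Tensor output = at::empty_like(input); // same total size as input
//   auto work = comm->all_to_all_single(output, input, /*async_op=*/true);
//   work->wait(); // output[2*i .. 2*i+1] now holds the chunk from rank i
// ---------------------------------------------------------------------------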
 std::shared_ptr<TorchWork> TorchCommXCCL::all_to_all_v_single(
-    at::Tensor& output,
-    const at::Tensor& input,
-    const std::vector& output_split_sizes,
-    const std::vector& input_split_sizes,
-    bool async_op,
-    const AllToAllvSingleOptions& options) {
+    at::Tensor &output, const at::Tensor &input,
+    const std::vector &output_split_sizes,
+    const std::vector &input_split_sizes, bool async_op,
+    const AllToAllvSingleOptions &options) {
   checkInitialized();
   checkAndAbortIfTimedOutOrError();
   ensureTensorContiguous(output);
@@ -949,23 +770,22 @@ std::shared_ptr<TorchWork> TorchCommXCCL::all_to_all_v_single(
 
   // Validate split sizes vectors
   if (input_split_sizes.size() != static_cast<size_t>(comm_size_)) {
-    throw std::runtime_error(
-        "input_split_sizes length must equal comm_size for all_to_all_v_single");
+    throw std::runtime_error("input_split_sizes length must equal comm_size "
+                             "for all_to_all_v_single");
   }
   if (output_split_sizes.size() != static_cast<size_t>(comm_size_)) {
-    throw std::runtime_error(
-        "output_split_sizes length must equal comm_size for all_to_all_v_single");
+    throw std::runtime_error("output_split_sizes length must equal comm_size "
+                             "for all_to_all_v_single");
   }
 
-  tracing_->recordEventWithInputOutput(
-      "all_to_all_v_single", rank_, {input}, {output});
+  tracing_->recordEventWithInputOutput("all_to_all_v_single", rank_, {input},
+                                       {output});
 
   xpuStream_t stream = getOperationStream(async_op);
   auto work = createWork(
       stream, getOperationTimeout(options.timeout, options_.timeout), {input});
 
-  // Record start event before XCCL operation
   work->recordStart();
 
   // Convert split sizes to arrays and calculate displacements
@@ -992,44 +812,31 @@ std::shared_ptr<TorchWork> TorchCommXCCL::all_to_all_v_single(
     recvoffset += recvcounts[i];
   }
 
-  char* sptr = static_cast<char*>(input.data_ptr());
-  char* rptr = static_cast<char*>(output.data_ptr());
+  char *sptr = static_cast<char*>(input.data_ptr());
+  char *rptr = static_cast<char*>(output.data_ptr());
 
   xccl_api_->groupStart();
 
   for (int i = 0; i < comm_size_; ++i) {
-    xccl_api_->send(
-        sptr + senddispls[i] * type_size,
-        sendcounts[i],
-        data_type,
-        i,
-        xccl_comm_,
-        stream);
-    xccl_api_->recv(
-        rptr + recvdispls[i] * type_size,
-        recvcounts[i],
-        data_type,
-        i,
-        xccl_comm_,
-        stream);
+    xccl_api_->send(sptr + senddispls[i] * type_size, sendcounts[i], data_type,
+                    i, xccl_comm_, stream);
+    xccl_api_->recv(rptr + recvdispls[i] * type_size, recvcounts[i], data_type,
+                    i, xccl_comm_, stream);
   }
 
   xccl_api_->groupEnd();
 
-  // Record end event after XCCL operation
   work->recordEnd();
 
-  // Enqueue the work after events have been recorded
   enqueueWork(work, stream);
 
   return work;
 }
 
-std::shared_ptr<TorchWork> TorchCommXCCL::all_to_all(
-    const std::vector<at::Tensor>& output_tensor_list,
-    const std::vector<at::Tensor>& input_tensor_list,
-    bool async_op,
-    const AllToAllOptions& options) {
+std::shared_ptr<TorchWork>
+TorchCommXCCL::all_to_all(const std::vector<at::Tensor> &output_tensor_list,
+                          const std::vector<at::Tensor> &input_tensor_list,
+                          bool async_op, const AllToAllOptions &options) {
   checkInitialized();
   checkAndAbortIfTimedOutOrError();
 
   if (output_tensor_list.size() != static_cast<size_t>(comm_size_) ||
@@ -1044,16 +851,14 @@ std::shared_ptr<TorchWork> TorchCommXCCL::all_to_all(
     ensureTensorContiguous(output_tensor_list[i]);
   }
 
-  tracing_->recordEventWithInputOutput(
-      "all_to_all", rank_, input_tensor_list, output_tensor_list);
+  tracing_->recordEventWithInputOutput("all_to_all", rank_, input_tensor_list,
+                                       output_tensor_list);
 
   xpuStream_t stream = getOperationStream(async_op);
-  auto work = createWork(
-      stream,
-      getOperationTimeout(options.timeout, options_.timeout),
-      input_tensor_list);
+ auto work = + createWork(stream, getOperationTimeout(options.timeout, options_.timeout), + input_tensor_list); - // Record start event before XCCL operations work->recordStart(); xccl_api_->groupStart(); @@ -1061,37 +866,26 @@ std::shared_ptr TorchCommXCCL::all_to_all( for (int i = 0; i < comm_size_; ++i) { // Send to rank i xccl_api_->send( - input_tensor_list[i].data_ptr(), - input_tensor_list[i].numel(), - getXcclDataType(input_tensor_list[i]), - i, - xccl_comm_, - stream); + input_tensor_list[i].data_ptr(), input_tensor_list[i].numel(), + getXcclDataType(input_tensor_list[i]), i, xccl_comm_, stream); // Receive from rank i xccl_api_->recv( - output_tensor_list[i].data_ptr(), - output_tensor_list[i].numel(), - getXcclDataType(output_tensor_list[i]), - i, - xccl_comm_, - stream); + output_tensor_list[i].data_ptr(), output_tensor_list[i].numel(), + getXcclDataType(output_tensor_list[i]), i, xccl_comm_, stream); } xccl_api_->groupEnd(); - // Record end event after XCCL operations work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::barrier( - bool async_op, - const BarrierOptions& options) { +std::shared_ptr +TorchCommXCCL::barrier(bool async_op, const BarrierOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); @@ -1100,38 +894,28 @@ std::shared_ptr TorchCommXCCL::barrier( auto work = createWork( stream, getOperationTimeout(options.timeout, options_.timeout), {}); - // Record start event before XCCL operation work->recordStart(); // Use pre-allocated XPU buffer for barrier - onecclResult_t result = xccl_api_->allReduce( - barrier_buffer_, - barrier_buffer_, - 1, - onecclFloat32, - onecclSum, - xccl_comm_, - stream); + onecclResult_t result = + xccl_api_->allReduce(barrier_buffer_, barrier_buffer_, 1, onecclFloat32, + onecclSum, xccl_comm_, stream); if (result != onecclSuccess) { throw XCCLException(*xccl_api_, "XCCL Barrier failed", result); } - // Record end event after XCCL operation work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::scatter( - at::Tensor& output_tensor, - const std::vector& input_tensor_list, - int root, - bool async_op, - const ScatterOptions& options) { +std::shared_ptr +TorchCommXCCL::scatter(at::Tensor &output_tensor, + const std::vector &input_tensor_list, + int root, bool async_op, const ScatterOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(output_tensor); @@ -1143,7 +927,7 @@ std::shared_ptr TorchCommXCCL::scatter( "input_tensor_list size must equal comm_size for scatter"); } - for (const auto& t : input_tensor_list) { + for (const auto &t : input_tensor_list) { ensureTensorContiguous(t); if (t.numel() != output_tensor.numel()) { throw std::runtime_error( @@ -1152,20 +936,18 @@ std::shared_ptr TorchCommXCCL::scatter( } } - tracing_->recordEventWithInputOutput( - "scatter", root, input_tensor_list, {output_tensor}); + tracing_->recordEventWithInputOutput("scatter", root, input_tensor_list, + {output_tensor}); xpuStream_t stream = getOperationStream(async_op); std::vector input_tensors; if (rank_ == root) { input_tensors = input_tensor_list; } - auto work = createWork( - stream, - getOperationTimeout(options.timeout, options_.timeout), - input_tensors); + auto work = + createWork(stream, getOperationTimeout(options.timeout, options_.timeout), + input_tensors); - // Record start event before XCCL operations 
work->recordStart(); // Implement scatter using point-to-point operations @@ -1175,52 +957,37 @@ std::shared_ptr TorchCommXCCL::scatter( for (int i = 0; i < comm_size_; ++i) { if (i != root) { xccl_api_->send( - input_tensor_list[i].data_ptr(), - input_tensor_list[i].numel(), - getXcclDataType(input_tensor_list[i]), - i, - xccl_comm_, - stream); + input_tensor_list[i].data_ptr(), input_tensor_list[i].numel(), + getXcclDataType(input_tensor_list[i]), i, xccl_comm_, stream); } } xccl_api_->groupEnd(); // Root copies its own data using xpuMemcpyAsync - XPU_CHECK( - xpu_api_, - xpu_api_->memcpyAsync( - output_tensor.data_ptr(), - input_tensor_list[root].data_ptr(), - input_tensor_list[root].numel() * - input_tensor_list[root].element_size(), - stream), - "memcpyAsync failed"); + XPU_CHECK(xpu_api_, + xpu_api_->memcpyAsync(output_tensor.data_ptr(), + input_tensor_list[root].data_ptr(), + input_tensor_list[root].numel() * + input_tensor_list[root].element_size(), + stream), + "memcpyAsync failed"); } else { // Non-root ranks receive from root - xccl_api_->recv( - output_tensor.data_ptr(), - output_tensor.numel(), - getXcclDataType(output_tensor), - root, - xccl_comm_, - stream); + xccl_api_->recv(output_tensor.data_ptr(), output_tensor.numel(), + getXcclDataType(output_tensor), root, xccl_comm_, stream); } - // Record end event after XCCL operations work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::gather( - const std::vector& output_tensor_list, - const at::Tensor& input_tensor, - int root, - bool async_op, - const GatherOptions& options) { +std::shared_ptr +TorchCommXCCL::gather(const std::vector &output_tensor_list, + const at::Tensor &input_tensor, int root, bool async_op, + const GatherOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); ensureTensorContiguous(input_tensor); @@ -1232,7 +999,7 @@ std::shared_ptr TorchCommXCCL::gather( "output_tensor_list size must equal comm_size for gather"); } - for (const auto& t : output_tensor_list) { + for (const auto &t : output_tensor_list) { ensureTensorContiguous(t); if (t.numel() != input_tensor.numel()) { throw std::runtime_error( @@ -1241,20 +1008,18 @@ std::shared_ptr TorchCommXCCL::gather( } } - tracing_->recordEventWithInputOutput( - "gather", root, {input_tensor}, output_tensor_list); + tracing_->recordEventWithInputOutput("gather", root, {input_tensor}, + output_tensor_list); xpuStream_t stream = getOperationStream(async_op); std::vector output_tensors; if (rank_ == root) { output_tensors = output_tensor_list; } - auto work = createWork( - stream, - getOperationTimeout(options.timeout, options_.timeout), - {input_tensor}); + auto work = + createWork(stream, getOperationTimeout(options.timeout, options_.timeout), + {input_tensor}); - // Record start event before XCCL operations work->recordStart(); if (rank_ == root) { @@ -1263,61 +1028,46 @@ std::shared_ptr TorchCommXCCL::gather( for (int i = 0; i < comm_size_; ++i) { if (i != root) { xccl_api_->recv( - output_tensor_list[i].data_ptr(), - output_tensor_list[i].numel(), - getXcclDataType(output_tensor_list[i]), - i, - xccl_comm_, - stream); + output_tensor_list[i].data_ptr(), output_tensor_list[i].numel(), + getXcclDataType(output_tensor_list[i]), i, xccl_comm_, stream); } } xccl_api_->groupEnd(); // Root copies its own data using xpuMemcpyAsync - XPU_CHECK( - xpu_api_, - xpu_api_->memcpyAsync( - output_tensor_list[root].data_ptr(), - input_tensor.data_ptr(), - 
input_tensor.numel() * input_tensor.element_size(), - stream), - "memcpyAsync failed"); + XPU_CHECK(xpu_api_, + xpu_api_->memcpyAsync( + output_tensor_list[root].data_ptr(), input_tensor.data_ptr(), + input_tensor.numel() * input_tensor.element_size(), stream), + "memcpyAsync failed"); } else { // Non-root ranks send to root - xccl_api_->send( - input_tensor.data_ptr(), - input_tensor.numel(), - getXcclDataType(input_tensor), - root, - xccl_comm_, - stream); + xccl_api_->send(input_tensor.data_ptr(), input_tensor.numel(), + getXcclDataType(input_tensor), root, xccl_comm_, stream); } - // Record end event after XCCL operations work->recordEnd(); - // Enqueue the work after events have been recorded enqueueWork(work, stream); return work; } -std::shared_ptr TorchCommXCCL::split( - const std::vector& ranks, - const std::string& name, - const CommOptions& options) { +std::shared_ptr +TorchCommXCCL::split(const std::vector &ranks, const std::string &name, + const CommOptions &options) { // Validate the ranks list checkAndAbortIfTimedOutOrError(); std::unordered_set rank_seen; for (int rank : ranks) { if (rank < 0 || rank >= comm_size_) { - throw std::runtime_error( - "Invalid rank " + std::to_string(rank) + - " in ranks. Valid ranks are 0 to " + std::to_string(comm_size_ - 1)); + throw std::runtime_error("Invalid rank " + std::to_string(rank) + + " in ranks. Valid ranks are 0 to " + + std::to_string(comm_size_ - 1)); } if (rank_seen.find(rank) != rank_seen.end()) { - throw std::runtime_error( - "Rank " + std::to_string(rank) + " appears multiple times in ranks"); + throw std::runtime_error("Rank " + std::to_string(rank) + + " appears multiple times in ranks"); } rank_seen.insert(rank); } @@ -1339,9 +1089,8 @@ std::shared_ptr TorchCommXCCL::split( auto it = std::find(ranks.begin(), ranks.end(), rank_); if (it == ranks.end()) { // Current rank is not in the non-empty list - this is an error - throw std::runtime_error( - "Current rank " + std::to_string(rank_) + - " is not included in the provided ranks list"); + throw std::runtime_error("Current rank " + std::to_string(rank_) + + " is not included in the provided ranks list"); } // Set color to the lowest rank in the group and calculate new rank color = *std::min_element(ranks.begin(), ranks.end()); @@ -1356,11 +1105,6 @@ std::shared_ptr TorchCommXCCL::split( // Populate XCCL config from user-provided hints populateXcclConfigFromHints(config, options, name); - // TODO: xccl says that this is not supposed to be called if any operation - // is outstanding on the comm. We should check for that. - // TODO: what happens if one rank fails but the others succeed, need to - // handle the error case. - // TODO: is this sharing any resources with the original comm? onecclResult_t result = xccl_api_->commSplit(xccl_comm_, color, new_rank, &new_comm, &config); if (result != onecclSuccess) { @@ -1380,7 +1124,7 @@ std::shared_ptr TorchCommXCCL::split( } void TorchCommXCCL::register_address( - const TorchCommXCCL::AddressWithLen& addr) { + const TorchCommXCCL::AddressWithLen &addr) { // We got a register after we got rid of the comm. Is this a fatal error? 
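// (Reviewer sketch, illustrative only — not part of this patch.) The
// registration path below is keyed purely by the base address: the memory
// hook hands each allocation to register_address(), which stores the
// onecclCommRegister handle in memoryRegistrationHandles_ under addr.addr,
// and deregister_address() later finds the handle by that same pointer.
// Conceptually:
//
//   comm.register_address({ptr, bytes}); // commRegister -> handle, keyed by ptr
//   comm.deregister_address({ptr});      // looks up handle by ptr, commDeregister
//
// A second register_address() on the same base pointer throws
// ("Memory already registered with XCCL").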
  if (!xccl_comm_) {
     return;
   }
@@ -1389,18 +1133,17 @@ void TorchCommXCCL::register_address(
   if (memoryRegistrationHandles_.contains(addr.addr)) {
     throw std::runtime_error("Memory already registered with XCCL");
   }
-  void* handle = nullptr;
+  void *handle = nullptr;
   onecclResult_t result =
       xccl_api_->commRegister(xccl_comm_, addr.addr, addr.len, &handle);
   if (result != onecclSuccess) {
-    throw std::runtime_error(
-        "Failed to register memory with XCCL: " +
-        std::string(onecclGetErrorString(result)));
+    throw std::runtime_error("Failed to register memory with XCCL: " +
+                             std::string(onecclGetErrorString(result)));
   }
   memoryRegistrationHandles_.emplace(addr.addr, RegistrationHandle(handle));
 }
 
-void TorchCommXCCL::deregister_address(const TorchCommXCCL::Address& addr) {
+void TorchCommXCCL::deregister_address(const TorchCommXCCL::Address &addr) {
   // We got a deregister after we got rid of the comm. Is this a fatal error?
   if (!xccl_comm_) {
     return;
   }
@@ -1408,41 +1151,34 @@ void TorchCommXCCL::deregister_address(const TorchCommXCCL::Address &addr) {
   auto it = memoryRegistrationHandles_.find(addr.addr);
   if (it == memoryRegistrationHandles_.end()) {
   // it's possible that the memory was registered for a different comm
   // but failed registration for this comm.
     return;
   }
 
-  void* handle = it->second.regHandle;
+  void *handle = it->second.regHandle;
   onecclResult_t result = xccl_api_->commDeregister(xccl_comm_, handle);
   if (result != onecclSuccess) {
-    throw std::runtime_error(
-        "Failed to deregister memory with XCCL: " +
-        std::string(xccl_api_->getErrorString(result)));
+    throw std::runtime_error("Failed to deregister memory with XCCL: " +
+                             std::string(xccl_api_->getErrorString(result)));
   }
   memoryRegistrationHandles_.erase(it);
 }
 
-XCCLException::XCCLException(
-    XcclApi& xccl_api,
-    const std::string& message,
-    onecclResult_t result)
+XCCLException::XCCLException(XcclApi &xccl_api, const std::string &message,
+                             onecclResult_t result)
     : message_(message + ": " + xccl_api.getErrorString(result)),
       result_(result) {}
 
-const char* XCCLException::what() const noexcept {
-  return message_.c_str();
-}
-
-
+const char *XCCLException::what() const noexcept { return message_.c_str(); }
 
 } // namespace comms
 } // namespace torch
 
 namespace {
 class XCCLRegistration {
- public:
+public:
  XCCLRegistration() {
    torch::comms::TorchCommFactory::get().register_backend("xccl", []() {
      return std::make_shared<torch::comms::TorchCommXCCL>();
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp
index a4b3af2b..c813fe25 100644
--- a/comms/torchcomms/xccl/TorchCommXCCL.hpp
+++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp
@@ -1,5 +1,3 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
- #pragma once #include @@ -32,35 +30,34 @@ constexpr size_t kMaxEventPoolSize = 1000; // Custom exception class for better error handling class XCCLException : public std::exception { - public: - XCCLException(XcclApi& api, const std::string& message, onecclResult_t result); +public: + XCCLException(XcclApi &api, const std::string &message, + onecclResult_t result); - const char* what() const noexcept override; + const char *what() const noexcept override; onecclResult_t getResult() const; - private: +private: std::string message_; onecclResult_t result_; }; class TorchCommXCCL : public TorchCommBackend, public std::enable_shared_from_this { - public: +public: static constexpr std::string_view kBackendName = "xccl"; TorchCommXCCL(); ~TorchCommXCCL() override; // Delete copy and move operations - TorchCommXCCL(const TorchCommXCCL&) = delete; - TorchCommXCCL(TorchCommXCCL&&) = delete; - TorchCommXCCL& operator=(const TorchCommXCCL&) = delete; - TorchCommXCCL& operator=(TorchCommXCCL&&) = delete; - - void init( - at::Device device, - const std::string& name, - const CommOptions& options = {}) override; + TorchCommXCCL(const TorchCommXCCL &) = delete; + TorchCommXCCL(TorchCommXCCL &&) = delete; + TorchCommXCCL &operator=(const TorchCommXCCL &) = delete; + TorchCommXCCL &operator=(TorchCommXCCL &&) = delete; + + void init(at::Device device, const std::string &name, + const CommOptions &options = {}) override; void finalize() override; int getRank() const override; int getSize() const override; @@ -68,102 +65,71 @@ class TorchCommXCCL : public TorchCommBackend, std::string_view getCommName() const override; // Point-to-Point Operations - std::shared_ptr send( - const at::Tensor& tensor, - int dst, - bool async_op, - const SendOptions& options = {}) override; - std::shared_ptr recv( - at::Tensor& tensor, - int src, - bool async_op, - const RecvOptions& options = {}) override; + std::shared_ptr send(const at::Tensor &tensor, int dst, + bool async_op, + const SendOptions &options = {}) override; + std::shared_ptr recv(at::Tensor &tensor, int src, bool async_op, + const RecvOptions &options = {}) override; // Batch P2P Operations - std::shared_ptr batch_op_issue( - const std::vector& ops, - bool async_op, - const BatchP2POptions& options = {}) override; + std::shared_ptr + batch_op_issue(const std::vector &ops, bool async_op, + const BatchP2POptions &options = {}) override; // Collective Operations - std::shared_ptr broadcast( - at::Tensor& tensor, - int root, - bool async_op, - const BroadcastOptions& options = {}) override; - std::shared_ptr all_reduce( - at::Tensor& tensor, - ReduceOp op, - bool async_op, - const AllReduceOptions& options = {}) override; - std::shared_ptr reduce( - const at::Tensor& tensor, - int root, - ReduceOp op, - bool async_op, - const ReduceOptions& options = {}) override; - std::shared_ptr all_gather( - const std::vector& tensor_list, - const at::Tensor& tensor, - bool async_op, - const AllGatherOptions& options = {}) override; - std::shared_ptr all_gather_single( - at::Tensor& output, - const at::Tensor& input, - bool async_op, - const AllGatherSingleOptions& options = {}) override; - std::shared_ptr reduce_scatter( - at::Tensor& output, - const std::vector& input_list, - ReduceOp op, - bool async_op, - const ReduceScatterOptions& options = {}) override; + std::shared_ptr + broadcast(at::Tensor &tensor, int root, bool async_op, + const BroadcastOptions &options = {}) override; + std::shared_ptr + all_reduce(at::Tensor &tensor, ReduceOp op, bool async_op, + const 
AllReduceOptions &options = {}) override; + std::shared_ptr reduce(const at::Tensor &tensor, int root, + ReduceOp op, bool async_op, + const ReduceOptions &options = {}) override; + std::shared_ptr + all_gather(const std::vector &tensor_list, + const at::Tensor &tensor, bool async_op, + const AllGatherOptions &options = {}) override; + std::shared_ptr + all_gather_single(at::Tensor &output, const at::Tensor &input, bool async_op, + const AllGatherSingleOptions &options = {}) override; + std::shared_ptr + reduce_scatter(at::Tensor &output, const std::vector &input_list, + ReduceOp op, bool async_op, + const ReduceScatterOptions &options = {}) override; std::shared_ptr reduce_scatter_single( - at::Tensor& output, - const at::Tensor& input, - ReduceOp op, - bool async_op, - const ReduceScatterSingleOptions& options = {}) override; - std::shared_ptr all_to_all_single( - at::Tensor& output, - const at::Tensor& input, - bool async_op, - const AllToAllSingleOptions& options = {}) override; - std::shared_ptr all_to_all_v_single( - at::Tensor& output, - const at::Tensor& input, - const std::vector& output_split_sizes, - const std::vector& input_split_sizes, - bool async_op, - const AllToAllvSingleOptions& options = {}) override; - std::shared_ptr all_to_all( - const std::vector& output_tensor_list, - const std::vector& input_tensor_list, - bool async_op, - const AllToAllOptions& options = {}) override; - std::shared_ptr barrier( - bool async_op, - const BarrierOptions& options = {}) override; + at::Tensor &output, const at::Tensor &input, ReduceOp op, bool async_op, + const ReduceScatterSingleOptions &options = {}) override; + std::shared_ptr + all_to_all_single(at::Tensor &output, const at::Tensor &input, bool async_op, + const AllToAllSingleOptions &options = {}) override; + std::shared_ptr + all_to_all_v_single(at::Tensor &output, const at::Tensor &input, + const std::vector &output_split_sizes, + const std::vector &input_split_sizes, + bool async_op, + const AllToAllvSingleOptions &options = {}) override; + std::shared_ptr + all_to_all(const std::vector &output_tensor_list, + const std::vector &input_tensor_list, bool async_op, + const AllToAllOptions &options = {}) override; + std::shared_ptr + barrier(bool async_op, const BarrierOptions &options = {}) override; // Scatter and Gather Operations - std::shared_ptr scatter( - at::Tensor& output_tensor, - const std::vector& input_tensor_list, - int root, - bool async_op, - const ScatterOptions& options = {}) override; - std::shared_ptr gather( - const std::vector& output_tensor_list, - const at::Tensor& input_tensor, - int root, - bool async_op, - const GatherOptions& options = {}) override; + std::shared_ptr + scatter(at::Tensor &output_tensor, + const std::vector &input_tensor_list, int root, + bool async_op, const ScatterOptions &options = {}) override; + std::shared_ptr + gather(const std::vector &output_tensor_list, + const at::Tensor &input_tensor, int root, bool async_op, + const GatherOptions &options = {}) override; // Communicator Management - std::shared_ptr split( - const std::vector& ranks, - const std::string& name, - const CommOptions& options = {}) override; + std::shared_ptr + split(const std::vector &ranks, const std::string &name, + const CommOptions &options = {}) override; // Friend access for TorchCommXCCL friend class TorchWorkXCCL; @@ -172,37 +138,25 @@ class TorchCommXCCL : public TorchCommBackend, friend class TorchCommWindowXCCL; // Getter for CUDA API (for friend classes) - XpuApi* getXpuApi() const { - return 
xpu_api_.get(); - } + XpuApi *getXpuApi() const { return xpu_api_.get(); } // Getter for XCCL API (for friend classes) - XcclApi* getXcclApi() const { - return xccl_api_.get(); - } + XcclApi *getXcclApi() const { return xccl_api_.get(); } // Method to override the XCCL API implementation for testing - void setXcclApi(std::shared_ptr api) { - xccl_api_ = std::move(api); - } + void setXcclApi(std::shared_ptr api) { xccl_api_ = std::move(api); } // Method to override the CUDA API implementation for testing - void setXpuApi(std::shared_ptr api) { - xpu_api_ = std::move(api); - } + void setXpuApi(std::shared_ptr api) { xpu_api_ = std::move(api); } - const CommOptions& getOptions() const override { - return options_; - } + const CommOptions &getOptions() const override { return options_; } - const at::Device& getDevice() const override { - return device_; - } + const at::Device &getDevice() const override { return device_; } - protected: +protected: // Event management for friend classes xpuEvent_t getEvent(); - void returnEvent(xpuEvent_t&& event); + void returnEvent(xpuEvent_t &&event); void abortXcclComm(); enum class CommState { @@ -212,47 +166,42 @@ class TorchCommXCCL : public TorchCommBackend, }; struct Address { - void* addr; + void *addr; }; struct AddressWithLen { - void* addr; + void *addr; size_t len; }; std::atomic comm_state_{ CommState::NORMAL}; // State of the communicator - void register_address(const AddressWithLen& addr); - void deregister_address(const Address& addr); - onecclDataType_t getXcclDataType(const at::Tensor& tensor); - std::shared_ptr createWork( - xpuStream_t stream, - std::chrono::milliseconds timeout, - const std::vector& inputTensors); + void register_address(const AddressWithLen &addr); + void deregister_address(const Address &addr); + onecclDataType_t getXcclDataType(const at::Tensor &tensor); + std::shared_ptr + createWork(xpuStream_t stream, std::chrono::milliseconds timeout, + const std::vector &inputTensors); - private: +private: // Helper that automatically cleans up premul sums. 
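Aside on the RedOpRAII helper declared just below: oneCCL premul-sum reductions are created at call time and must be destroyed afterwards, so the patch wraps the handle in a copy- and move-deleted RAII type. A hedged sketch of the same pattern with stand-in types (Api and RedOp here are illustrative, not the oneCCL API):

    using RedOp = int;  // stand-in for onecclRedOp_t
    struct Api {
      void redOpDestroy(RedOp /*op*/) { /* would call onecclRedOpDestroy */ }
    };

    class ScopedRedOp {
     public:
      ScopedRedOp(Api* api, RedOp op) : api_(api), op_(op) {}
      ~ScopedRedOp() {
        if (api_) api_->redOpDestroy(op_);  // destroy exactly once, on scope exit
      }
      ScopedRedOp(const ScopedRedOp&) = delete;             // copies would risk
      ScopedRedOp& operator=(const ScopedRedOp&) = delete;  // double destroy
      operator RedOp() const { return op_; }  // usable where a raw op is expected
     private:
      Api* api_;
      RedOp op_;
    };

Deleting copy and move, as the real declaration does, guarantees the destroy call runs exactly once; built-in ops like onecclSum simply skip the destroy path.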
struct RedOpRAII { /* implicit */ RedOpRAII(onecclRedOp_t op); // Constructor for Premulsum Reduction - explicit RedOpRAII( - const ReduceOp& op, - const onecclComm_t comm, - const onecclDataType_t dataType, - std::shared_ptr xccl_api); + explicit RedOpRAII(const ReduceOp &op, const onecclComm_t comm, + const onecclDataType_t dataType, + std::shared_ptr xccl_api); RedOpRAII() = delete; - RedOpRAII(const RedOpRAII&) = delete; - RedOpRAII& operator=(const RedOpRAII&) = delete; - RedOpRAII(RedOpRAII&& tmp) = delete; - RedOpRAII& operator=(RedOpRAII&&) = delete; + RedOpRAII(const RedOpRAII &) = delete; + RedOpRAII &operator=(const RedOpRAII &) = delete; + RedOpRAII(RedOpRAII &&tmp) = delete; + RedOpRAII &operator=(RedOpRAII &&) = delete; ~RedOpRAII(); - operator onecclRedOp_t() const { - return xcclRedOp_; - } + operator onecclRedOp_t() const { return xcclRedOp_; } onecclRedOp_t xcclRedOp_{onecclMaxRedOp}; onecclComm_t comm_{nullptr}; @@ -261,18 +210,18 @@ class TorchCommXCCL : public TorchCommBackend, // Struct to hold the registration handle for a buffer struct RegistrationHandle { - void* regHandle; + void *regHandle; - explicit RegistrationHandle(void* regHandle) : regHandle{regHandle} {} + explicit RegistrationHandle(void *regHandle) : regHandle{regHandle} {} - RegistrationHandle(RegistrationHandle&& other) noexcept + RegistrationHandle(RegistrationHandle &&other) noexcept : regHandle{other.regHandle} { other.regHandle = nullptr; } - RegistrationHandle(const RegistrationHandle&) = delete; - RegistrationHandle& operator=(const RegistrationHandle&) = delete; - RegistrationHandle& operator=(RegistrationHandle&&) = delete; + RegistrationHandle(const RegistrationHandle &) = delete; + RegistrationHandle &operator=(const RegistrationHandle &) = delete; + RegistrationHandle &operator=(RegistrationHandle &&) = delete; ~RegistrationHandle() = default; }; @@ -282,19 +231,15 @@ class TorchCommXCCL : public TorchCommBackend, // Private utility methods size_t wordSize(onecclDataType_t type) const; - RedOpRAII getXcclReduceOp( - const ReduceOp& op, - const onecclComm_t comm, - const onecclDataType_t dataType); + RedOpRAII getXcclReduceOp(const ReduceOp &op, const onecclComm_t comm, + const onecclDataType_t dataType); void timeoutWatchdog() noexcept; void checkInitialized() const; void checkAndAbortIfTimedOutOrError(); void checkWorkQueue(bool isMainThread); void enqueueWork(std::shared_ptr work, xpuStream_t stream); - // XPU doesn't support graph capture yet - // bool getGraphCaptureMode(); xpuStream_t getOperationStream(bool async_op); - void ensureTensorContiguous(const at::Tensor& tensor); + void ensureTensorContiguous(const at::Tensor &tensor); void attachMemoryHook(); void detachMemoryHook(); @@ -308,8 +253,8 @@ class TorchCommXCCL : public TorchCommBackend, size_t max_event_pool_size_{}; std::optional internal_stream_; // Initialized in init() std::optional - dependency_event_; // Pre-allocated event for stream dependencies - void* barrier_buffer_{}; // Pre-allocated CUDA buffer for barrier operations + dependency_event_; // Pre-allocated event for stream dependencies + void *barrier_buffer_{}; // Pre-allocated CUDA buffer for barrier operations enum class InitializationState { UNINITIALIZED, INITIALIZED, @@ -318,7 +263,7 @@ class TorchCommXCCL : public TorchCommBackend, // List of [comm, regHandlesMap] pairs. 
Each regHandlesMap is a map from the // buffer address to the registeration handle - std::map memoryRegistrationHandles_; + std::map memoryRegistrationHandles_; // XCCL API abstraction std::shared_ptr xccl_api_; @@ -346,26 +291,26 @@ class TorchCommXCCL : public TorchCommBackend, // Graph capture mode work references // Keep references to work objects during graph capture to prevent premature // destruction, organized per graph using capture ID - std::unordered_map< - unsigned long long, - std::vector>> + std::unordered_map>> graph_capture_work_refs_; std::mutex graph_capture_work_mutex_; // Structure to hold cleanup data for XPU user objects // NOTE: Graph capture cleanup is currently disabled for XPU/SYCL - // as the required APIs (userObjectCreate, graphRetainUserObject) are not yet available + // as the required APIs (userObjectCreate, graphRetainUserObject) are not yet + // available struct GraphCleanupData { - TorchCommXCCL* comm; + TorchCommXCCL *comm; unsigned long long graph_id; - GraphCleanupData(TorchCommXCCL* comm_, unsigned long long id) + GraphCleanupData(TorchCommXCCL *comm_, unsigned long long id) : comm(comm_), graph_id(id) {} }; // Static callback function for XPU user object cleanup - // NOTE: Currently disabled - XPU/SYCL does not have equivalent callback mechanism - // static void graphCleanupCallback(void* userData); + // NOTE: Currently disabled - XPU/SYCL does not have equivalent callback + // mechanism static void graphCleanupCallback(void* userData); friend class TorchWorkXCCLQueueCommTest; }; diff --git a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp index 5462d5bf..14b91015 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp @@ -1,12 +1,10 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
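The constructor and exchangeUniqueIdStore() below move an opaque POD unique ID through the store as raw bytes: rank 0 calls getUniqueId() and set()s the byte image under a shared key, every other rank get()s it, checks the size, and reinterprets the buffer. A sketch of that round trip, assuming only that the ID is trivially copyable (UniqueId is a stand-in for onecclUniqueId):

    #include <cstdint>
    #include <cstring>
    #include <stdexcept>
    #include <vector>

    struct UniqueId { char internal[128]; };  // stand-in for onecclUniqueId

    // Rank 0 side: flatten the POD id into the byte vector the store accepts.
    std::vector<uint8_t> toBytes(const UniqueId& id) {
      const auto* p = reinterpret_cast<const uint8_t*>(&id);
      return std::vector<uint8_t>(p, p + sizeof(id));
    }

    // Receiver side: validate the size, then copy the bytes back out.
    UniqueId fromBytes(const std::vector<uint8_t>& v) {
      if (v.size() != sizeof(UniqueId)) {
        throw std::runtime_error("Invalid unique ID size");
      }
      UniqueId id;
      std::memcpy(&id, v.data(), sizeof(id));
      return id;
    }

The memcpy on the read side sidesteps the alignment caveat of the reinterpret_cast the hunk uses; the two are equivalent when the vector's buffer happens to be suitably aligned.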
- #include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp" -#include -#include // @manual #include "comms/torchcomms/StoreManager.hpp" #include "comms/torchcomms/TorchCommLogging.hpp" #include "comms/torchcomms/TorchCommUtils.hpp" #include "comms/torchcomms/xccl/TorchCommXCCL.hpp" +#include +#include // @manual namespace torch { namespace comms { @@ -19,23 +17,17 @@ const std::string kUniqueidXchgMethodTCPStore = "tcpstore"; const std::string kUniqueidXchgMethodDefault = kUniqueidXchgMethodAuto; TorchCommXCCLBootstrap::TorchCommXCCLBootstrap( - c10::intrusive_ptr store, - c10::Device device, - std::shared_ptr xccl_api, - std::shared_ptr xpu_api, + c10::intrusive_ptr store, c10::Device device, + std::shared_ptr xccl_api, std::shared_ptr xpu_api, std::chrono::milliseconds timeout) - : timeout_(timeout), - store_(store), - created_internal_store_(false), - device_(device), - xccl_api_(xccl_api), - xpu_api_(xpu_api) { + : timeout_(timeout), store_(store), created_internal_store_(false), + device_(device), xccl_api_(xccl_api), xpu_api_(xpu_api) { // Query rank and size using the utility function auto ranksize = query_ranksize(); rank_ = ranksize.first; comm_size_ = ranksize.second; - const char* uniqueid_xchg_env = + const char *uniqueid_xchg_env = std::getenv("TORCHCOMM_XCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD"); if (uniqueid_xchg_env == nullptr) { TC_LOG(INFO) @@ -45,42 +37,32 @@ TorchCommXCCLBootstrap::TorchCommXCCLBootstrap( } else { uniqueid_xchg_method_ = uniqueid_xchg_env; } - std::transform( - uniqueid_xchg_method_.begin(), - uniqueid_xchg_method_.end(), - uniqueid_xchg_method_.begin(), - [](unsigned char c) { return std::tolower(c); }); + std::transform(uniqueid_xchg_method_.begin(), uniqueid_xchg_method_.end(), + uniqueid_xchg_method_.begin(), + [](unsigned char c) { return std::tolower(c); }); if (device_.index() == -1) { int device_count; - XPU_CHECK( - xpu_api_, - xpu_api_->getDeviceCount(&device_count), - "Failed to get XPU device count"); + XPU_CHECK(xpu_api_, xpu_api_->getDeviceCount(&device_count), + "Failed to get XPU device count"); device_ = c10::Device(c10::kXPU, rank_ % device_count); TC_LOG(INFO) << "User did not provide device ID; using device xpu:" << static_cast(device_.index()); } - XPU_CHECK( - xpu_api_, - xpu_api_->setDevice(device_.index()), - "Failed to set device to " + std::to_string(device_.index())); + XPU_CHECK(xpu_api_, xpu_api_->setDevice(device_.index()), + "Failed to set device to " + std::to_string(device_.index())); // Allocate XPU memory for a single float32 value used in barrier operations - XPU_CHECK( - xpu_api_, - xpu_api_->malloc(&barrier_buffer_, sizeof(float)), - "Failed to allocate barrier buffer"); + XPU_CHECK(xpu_api_, xpu_api_->malloc(&barrier_buffer_, sizeof(float)), + "Failed to allocate barrier buffer"); } TorchCommXCCLBootstrap::~TorchCommXCCLBootstrap() { if (barrier_buffer_ != nullptr) { - XPU_CHECK( - xpu_api_, - xpu_api_->free(barrier_buffer_), - "Failed to free barrier buffer"); + XPU_CHECK(xpu_api_, xpu_api_->free(barrier_buffer_), + "Failed to free barrier buffer"); barrier_buffer_ = nullptr; } } @@ -95,53 +77,42 @@ std::string TorchCommXCCLBootstrap::getXCCLStoreKeyPrefix() { return "xccl_storekey_"; }; -int TorchCommXCCLBootstrap::getXCCLStoreKeyCounter() { - return counter_; -} +int TorchCommXCCLBootstrap::getXCCLStoreKeyCounter() { return counter_; } onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueIdStore() { onecclUniqueId uniqueId; auto key = getXCCLStoreKey(); - TC_LOG(INFO) << "[TC] Using store key: " << key << " for 
rank " << rank_; - + if (rank_ == 0) { // Generate unique ID on rank 0 - TC_LOG(INFO) << "[TC] Rank 0: calling getUniqueId"; onecclResult_t xcclErr = xccl_api_->getUniqueId(&uniqueId); - TC_LOG(INFO) << "[TC] Rank 0: getUniqueId returned " << xcclErr; - + if (xcclErr != onecclSuccess) { - throw std::runtime_error( - "Failed to get XCCL unique ID: " + - std::string(xccl_api_->getErrorString(xcclErr))); + throw std::runtime_error("Failed to get XCCL unique ID: " + + std::string(xccl_api_->getErrorString(xcclErr))); } // Set the unique ID in the store - TC_LOG(INFO) << "[TC] Rank 0: setting unique ID in store"; - std::vector vec( - reinterpret_cast(&uniqueId), - reinterpret_cast(&uniqueId) + sizeof(uniqueId)); + std::vector vec(reinterpret_cast(&uniqueId), + reinterpret_cast(&uniqueId) + + sizeof(uniqueId)); store_->set(key, vec); - TC_LOG(INFO) << "[TC] Rank 0: unique ID set in store"; } else { // Other ranks read the broadcast ID - TC_LOG(INFO) << "[TC] Rank " << rank_ << ": getting unique ID from store"; auto vec = store_->get(key); - TC_LOG(INFO) << "[TC] Rank " << rank_ << ": got unique ID from store, size=" << vec.size(); - + if (vec.size() != sizeof(onecclUniqueId)) { throw std::runtime_error("Invalid XCCL unique ID size"); } - uniqueId = *(reinterpret_cast(vec.data())); + uniqueId = *(reinterpret_cast(vec.data())); } - TC_LOG(INFO) << "[TC] Rank " << rank_ << ": unique ID exchange completed"; return uniqueId; } -onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueIdTCPStore( - std::string_view name) { +onecclUniqueId +TorchCommXCCLBootstrap::exchangeUniqueIdTCPStore(std::string_view name) { store_ = StoreManager::get().getStore(TorchCommXCCL::kBackendName, name, timeout_); created_internal_store_ = true; @@ -161,8 +132,8 @@ onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueId(std::string_view name) { bool is_tcp_store_enabled = isTCPStoreEnabled(); if (uniqueid_xchg_method_ != kUniqueidXchgMethodAuto && uniqueid_xchg_method_ != kUniqueidXchgMethodTCPStore) { - throw std::runtime_error( - "Invalid unique ID exchange method " + uniqueid_xchg_method_); + throw std::runtime_error("Invalid unique ID exchange method " + + uniqueid_xchg_method_); } if (!is_tcp_store_enabled) { throw std::runtime_error("No way to exchange unique ID"); @@ -179,38 +150,30 @@ void TorchCommXCCLBootstrap::cleanupTCPStore(onecclComm_t xccl_comm) { store_.reset(); auto stream = xpu_api_->getCurrentXPUStream(device_.index()); - onecclResult_t result = xccl_api_->allReduce( - barrier_buffer_, - barrier_buffer_, - 1, - onecclFloat32, - onecclSum, - xccl_comm, - stream); + onecclResult_t result = + xccl_api_->allReduce(barrier_buffer_, barrier_buffer_, 1, onecclFloat32, + onecclSum, xccl_comm, stream); if (result != onecclSuccess) { TC_LOG(ERROR) << "XCCL AllReduce failed: " << xccl_api_->getErrorString(result); } - XPU_CHECK( - xpu_api_, - xpu_api_->streamSynchronize(stream), - "Stream synchronization failed"); + XPU_CHECK(xpu_api_, xpu_api_->streamSynchronize(stream), + "Stream synchronization failed"); } } // Helper function to populate XCCL config from hints -void populateXcclConfigFromHints( - onecclConfig_t& config, - const CommOptions& options, - const std::string& name) { +void populateXcclConfigFromHints(onecclConfig_t &config, + const CommOptions &options, + const std::string &name) { // Iterate over the hints and set the corresponding fields in the config. For // string arguments, XCCL uses a "const char*" instead of a std::string, so // it is hard to figure out the ownership structure. 
Here, we create a copy // of the string and pass it to XCCL, so that it is responsible for freeing // it. - for (const auto& [key, val] : options.hints) { + for (const auto &[key, val] : options.hints) { if (key == "blocking") { config.blocking = std::stoi(val); TC_LOG(INFO) << "[comm=" << name @@ -236,17 +199,17 @@ void populateXcclConfigFromHints( TC_LOG(INFO) << "[comm=" << name << "] Setting config.splitShare=" << config.splitShare; } else if (key == "trafficClass" || key == "traffic_class" || - key == "commName" || - key == "collnetEnable" || key == "collnet_enable" || - key == "CTAPolicy" || key == "cta_policy" || - key == "shrinkShare" || + key == "commName" || key == "collnetEnable" || + key == "collnet_enable" || key == "CTAPolicy" || + key == "cta_policy" || key == "shrinkShare" || key == "nvlsCTAs" || key == "nvls_ctas" || - key == "nChannelsPerNetPeer" || key == "n_channels_per_net_peer" || + key == "nChannelsPerNetPeer" || + key == "n_channels_per_net_peer" || key == "nvlinkCentricSched" || key == "nvlink_centric_sched") { - TC_LOG(WARNING) - << "XCCL hint '" << key - << "' is NCCL-specific and not supported by oneCCL, ignoring for comm '" - << name << "'"; + TC_LOG(WARNING) << "XCCL hint '" << key + << "' is NCCL-specific and not supported by oneCCL, " + "ignoring for comm '" + << name << "'"; } else { TC_LOG(WARNING) << "XCCL hint '" << key @@ -256,48 +219,34 @@ void populateXcclConfigFromHints( } } -onecclComm_t TorchCommXCCLBootstrap::createXcclComm( - const std::string& name, - const CommOptions& options) { +onecclComm_t +TorchCommXCCLBootstrap::createXcclComm(const std::string &name, + const CommOptions &options) { onecclUniqueId uniqueId; onecclComm_t xccl_comm = nullptr; - TC_LOG(INFO) << "[TC] Exchanging unique ID for comm '" << name << "'"; uniqueId = exchangeUniqueId(name); - TC_LOG(INFO) << "[TC] Unique ID exchanged"; - // TODO: add logging on failures and successes - // TODO: use scalable init - // TODO: get the local rank - TC_LOG(INFO) << "[TC] Initializing XCCL config"; onecclConfig_t config = ONECCL_CONFIG_INITIALIZER; - // Note: oneCCL does not have a commName field like NCCL // Populate XCCL config from user-provided hints populateXcclConfigFromHints(config, options, name); // Set device for oneCCL before initializing communicator - TC_LOG(INFO) << "[TC] Setting oneCCL device to " << device_.index(); onecclResult_t xcclErr = xccl_api_->setDevice(device_.index()); if (xcclErr != onecclSuccess) { - throw std::runtime_error( - "Failed to set oneCCL device: " + - std::string(xccl_api_->getErrorString(xcclErr))); + throw std::runtime_error("Failed to set oneCCL device: " + + std::string(xccl_api_->getErrorString(xcclErr))); } - TC_LOG(INFO) << "[TC] Calling commInitRankConfig with rank=" << rank_ - << " comm_size=" << comm_size_; - xcclErr = xccl_api_->commInitRankConfig( - &xccl_comm, comm_size_, uniqueId, rank_, &config); - TC_LOG(INFO) << "[TC] commInitRankConfig returned: " << xcclErr; - + xcclErr = xccl_api_->commInitRankConfig(&xccl_comm, comm_size_, uniqueId, + rank_, &config); + if (xcclErr != onecclSuccess || xccl_comm == nullptr) { - throw std::runtime_error( - "Failed to initialize XCCL communicator: " + - std::string(xccl_api_->getErrorString(xcclErr))); + throw std::runtime_error("Failed to initialize XCCL communicator: " + + std::string(xccl_api_->getErrorString(xcclErr))); } - TC_LOG(INFO) << "[TC] XCCL communicator initialized, cleaning up TCPStore"; cleanupTCPStore(xccl_comm); return xccl_comm; diff --git 
a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp index 152611ef..6b43f555 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp +++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp @@ -1,11 +1,8 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. - #pragma once #include #include -// #include // @manual=third-party//xpu:xpu-lazy #include // @manual=//caffe2:torch-cpp #include "comms/torchcomms/TorchCommOptions.hpp" @@ -20,46 +17,37 @@ namespace comms { constexpr uint16_t kTCPStorePort = 29500; class TorchCommXCCLBootstrap { - public: - TorchCommXCCLBootstrap( - c10::intrusive_ptr store, - c10::Device device, - std::shared_ptr xccl_api, - std::shared_ptr xpu_api, - std::chrono::milliseconds timeout); +public: + TorchCommXCCLBootstrap(c10::intrusive_ptr store, + c10::Device device, std::shared_ptr xccl_api, + std::shared_ptr xpu_api, + std::chrono::milliseconds timeout); ~TorchCommXCCLBootstrap(); // Delete copy and move operations - TorchCommXCCLBootstrap(const TorchCommXCCLBootstrap&) = delete; - TorchCommXCCLBootstrap& operator=(const TorchCommXCCLBootstrap&) = delete; - TorchCommXCCLBootstrap(TorchCommXCCLBootstrap&&) = delete; - TorchCommXCCLBootstrap& operator=(TorchCommXCCLBootstrap&&) = delete; + TorchCommXCCLBootstrap(const TorchCommXCCLBootstrap &) = delete; + TorchCommXCCLBootstrap &operator=(const TorchCommXCCLBootstrap &) = delete; + TorchCommXCCLBootstrap(TorchCommXCCLBootstrap &&) = delete; + TorchCommXCCLBootstrap &operator=(TorchCommXCCLBootstrap &&) = delete; - onecclComm_t createXcclComm( - const std::string& name, - const CommOptions& options = {}); + onecclComm_t createXcclComm(const std::string &name, + const CommOptions &options = {}); static std::string getXCCLStoreKey(); static std::string getXCCLStoreKeyPrefix(); static int getXCCLStoreKeyCounter(); - int getRank() { - return rank_; - } - int getSize() { - return comm_size_; - } - c10::Device getDevice() { - return device_; - } + int getRank() { return rank_; } + int getSize() { return comm_size_; } + c10::Device getDevice() { return device_; } - private: +private: onecclUniqueId exchangeUniqueId(std::string_view name); onecclUniqueId exchangeUniqueIdStore(); onecclUniqueId exchangeUniqueIdTCPStore(std::string_view name); bool isTCPStoreEnabled(); void cleanupTCPStore(onecclComm_t xccl_comm); - private: +private: const std::chrono::milliseconds timeout_; static int counter_; @@ -68,7 +56,7 @@ class TorchCommXCCLBootstrap { c10::Device device_; std::shared_ptr xccl_api_; std::shared_ptr xpu_api_; - void* barrier_buffer_{nullptr}; + void *barrier_buffer_{nullptr}; int rank_; int comm_size_; @@ -76,10 +64,9 @@ class TorchCommXCCLBootstrap { }; // Helper function to populate XCCL config from hints -void populateXcclConfigFromHints( - onecclConfig_t& config, - const CommOptions& options, - const std::string& name); +void populateXcclConfigFromHints(onecclConfig_t &config, + const CommOptions &options, + const std::string &name); } // namespace comms } // namespace torch diff --git a/comms/torchcomms/xccl/TorchCommXCCLPy.cpp b/comms/torchcomms/xccl/TorchCommXCCLPy.cpp index aa1780ed..aa9f0829 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLPy.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCLPy.cpp @@ -1,5 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
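Stepping back to populateXcclConfigFromHints in the Bootstrap diff above: hints arrive as a string-to-string map and are parsed field by field with std::stoi, while NCCL-only keys draw a warning and are skipped. A runnable sketch of that dispatch; Config here holds just two of the fields the hunk sets, the real onecclConfig_t has more.

    #include <map>
    #include <string>

    struct Config {
      int blocking = -1;   // -1: leave the library default in place
      int splitShare = 0;
    };

    void applyHints(Config& cfg, const std::map<std::string, std::string>& hints) {
      for (const auto& [key, val] : hints) {
        if (key == "blocking") {
          cfg.blocking = std::stoi(val);      // throws on non-numeric input
        } else if (key == "splitShare") {
          cfg.splitShare = std::stoi(val);
        }
        // the real code also warns on NCCL-specific keys and ignores unknowns
      }
    }

Because std::stoi throws std::invalid_argument, malformed hint values surface as exceptions at init time rather than as silent defaults.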
- #include #include #include diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp index 79c35d94..a96901e3 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp @@ -1,59 +1,54 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. - #include "comms/torchcomms/xccl/TorchCommXCCL.hpp" // #include "comms/torchcomms/xccl/TorchCommXCCLCCA.hpp" -#include -#include #include "comms/torchcomms/TorchCommLogging.hpp" #include #include +#include +#include namespace torch { namespace comms { namespace { -onecclDataType_t getXcclDataTypeInternal(const at::Tensor& tensor) { +onecclDataType_t getXcclDataTypeInternal(const at::Tensor &tensor) { switch (tensor.scalar_type()) { - case at::ScalarType::Float: - return onecclFloat32; - case at::ScalarType::Double: - return onecclFloat64; - case at::ScalarType::Half: - return onecclFloat16; - case at::ScalarType::BFloat16: - return onecclBfloat16; - case at::ScalarType::Int: - return onecclInt32; - case at::ScalarType::Long: - return onecclInt64; - case at::ScalarType::Char: - return onecclInt8; - case at::ScalarType::Byte: - return onecclUint8; - default: - throw std::runtime_error("Unsupported tensor data type for XCCL"); + case at::ScalarType::Float: + return onecclFloat32; + case at::ScalarType::Double: + return onecclFloat64; + case at::ScalarType::Half: + return onecclFloat16; + case at::ScalarType::BFloat16: + return onecclBfloat16; + case at::ScalarType::Int: + return onecclInt32; + case at::ScalarType::Long: + return onecclInt64; + case at::ScalarType::Char: + return onecclInt8; + case at::ScalarType::Byte: + return onecclUint8; + default: + throw std::runtime_error("Unsupported tensor data type for XCCL"); } } template -void createPreMulSum( - onecclRedOp_t* op, - const PreMulSumFactorT& factor, - const onecclComm_t& comm, - XcclApi* xccl_api) { +void createPreMulSum(onecclRedOp_t *op, const PreMulSumFactorT &factor, + const onecclComm_t &comm, XcclApi *xccl_api) { const bool is_tensor = std::holds_alternative(factor); - const auto residence = is_tensor ? onecclScalarDevice : onecclScalarHostImmediate; + const auto residence = + is_tensor ? onecclScalarDevice : onecclScalarHostImmediate; at::Tensor tensor = is_tensor ? std::get(factor) : at::Tensor(); T scalar_factor = is_tensor ? T{} : static_cast(std::get(factor)); - void* scalar = is_tensor ? tensor.data_ptr() : &scalar_factor; + void *scalar = is_tensor ? tensor.data_ptr() : &scalar_factor; - TORCH_INTERNAL_ASSERT( - is_tensor ? dataType == getXcclDataTypeInternal(tensor) - : dataType != onecclBfloat16, - "PreMulSum factor type must match input data type"); + TORCH_INTERNAL_ASSERT(is_tensor ? 
dataType == getXcclDataTypeInternal(tensor) + : dataType != onecclBfloat16, + "PreMulSum factor type must match input data type"); xccl_api->redOpCreatePreMulSum(op, scalar, dataType, residence, comm); } @@ -62,11 +57,9 @@ void createPreMulSum( TorchCommXCCL::RedOpRAII::RedOpRAII(onecclRedOp_t op) : xcclRedOp_(op), comm_(nullptr) {} -TorchCommXCCL::RedOpRAII::RedOpRAII( - const ReduceOp& op, - const onecclComm_t comm, - const onecclDataType_t dataType, - std::shared_ptr xccl_api) +TorchCommXCCL::RedOpRAII::RedOpRAII(const ReduceOp &op, const onecclComm_t comm, + const onecclDataType_t dataType, + std::shared_ptr xccl_api) : comm_(comm), xccl_api_(std::move(xccl_api)) { TORCH_INTERNAL_ASSERT( op == ReduceOp::RedOpType::PREMUL_SUM, @@ -78,27 +71,27 @@ TorchCommXCCL::RedOpRAII::RedOpRAII( return; } - const auto& factor = op.factor().value(); + const auto &factor = op.factor().value(); switch (dataType) { - case onecclFloat16: - createPreMulSum( - &xcclRedOp_, factor, comm, xccl_api_.get()); - break; - case onecclFloat32: - createPreMulSum( - &xcclRedOp_, factor, comm, xccl_api_.get()); - break; - case onecclBfloat16: - createPreMulSum( - &xcclRedOp_, factor, comm, xccl_api_.get()); - break; - case onecclFloat64: - createPreMulSum( - &xcclRedOp_, factor, comm, xccl_api_.get()); - break; - default: - throw std::runtime_error( - "PreMulSum Data type must be half, float, bfloat16 or double"); + case onecclFloat16: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + case onecclFloat32: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + case onecclBfloat16: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + case onecclFloat64: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + default: + throw std::runtime_error( + "PreMulSum Data type must be half, float, bfloat16 or double"); } } @@ -110,69 +103,53 @@ TorchCommXCCL::RedOpRAII::~RedOpRAII() { size_t TorchCommXCCL::wordSize(onecclDataType_t type) const { switch (type) { - case onecclChar: -#if XCCL_MAJOR >= 2 - // case onecclInt8: - case onecclUint8: -#endif -// #if HAVE_FP8 -// case onecclFloat8e4m3: -// case onecclFloat8e5m2: -// #endif - return 1; - case onecclHalf: -#if HAVE_BF16 - case onecclBfloat16: -#endif - // case onecclFloat16: - return 2; - case onecclInt: - case onecclFloat: -#if XCCL_MAJOR >= 2 - // case onecclInt32: - case onecclUint32: - // case onecclFloat32: -#endif - return 4; - case onecclInt64: - case onecclUint64: - case onecclDouble: - // case onecclFloat64: - return 8; - default: - return 0; + case onecclInt8: + case onecclUint8: + return 1; + case onecclFloat16: + case onecclBfloat16: + return 2; + case onecclInt32: + case onecclUint32: + case onecclFloat32: + return 4; + case onecclInt64: + case onecclUint64: + case onecclFloat64: + return 8; + default: + return 0; } } -onecclDataType_t TorchCommXCCL::getXcclDataType(const at::Tensor& tensor) { +onecclDataType_t TorchCommXCCL::getXcclDataType(const at::Tensor &tensor) { return getXcclDataTypeInternal(tensor); } -TorchCommXCCL::RedOpRAII TorchCommXCCL::getXcclReduceOp( - const ReduceOp& op, - const onecclComm_t comm, - const onecclDataType_t dataType) { +TorchCommXCCL::RedOpRAII +TorchCommXCCL::getXcclReduceOp(const ReduceOp &op, const onecclComm_t comm, + const onecclDataType_t dataType) { switch (op) { - case ReduceOp::RedOpType::SUM: - return onecclSum; - case ReduceOp::RedOpType::PRODUCT: - return onecclProd; - case ReduceOp::RedOpType::MIN: - return onecclMin; - case 
ReduceOp::RedOpType::MAX: - return onecclMax; - case ReduceOp::RedOpType::BAND: - return onecclSum; // XCCL doesn't have bitwise AND, using SUM as fallback - case ReduceOp::RedOpType::BOR: - return onecclSum; // XCCL doesn't have bitwise OR, using SUM as fallback - case ReduceOp::RedOpType::BXOR: - return onecclSum; // XCCL doesn't have bitwise XOR, using SUM as fallback - case ReduceOp::RedOpType::PREMUL_SUM: - return RedOpRAII(op, comm, dataType, xccl_api_); - case ReduceOp::RedOpType::AVG: - return onecclAvg; - default: - throw std::runtime_error("Unsupported reduce operation"); + case ReduceOp::RedOpType::SUM: + return onecclSum; + case ReduceOp::RedOpType::PRODUCT: + return onecclProd; + case ReduceOp::RedOpType::MIN: + return onecclMin; + case ReduceOp::RedOpType::MAX: + return onecclMax; + case ReduceOp::RedOpType::BAND: + return onecclSum; // XCCL doesn't have bitwise AND, using SUM as fallback + case ReduceOp::RedOpType::BOR: + return onecclSum; // XCCL doesn't have bitwise OR, using SUM as fallback + case ReduceOp::RedOpType::BXOR: + return onecclSum; // XCCL doesn't have bitwise XOR, using SUM as fallback + case ReduceOp::RedOpType::PREMUL_SUM: + return RedOpRAII(op, comm, dataType, xccl_api_); + case ReduceOp::RedOpType::AVG: + return onecclAvg; + default: + throw std::runtime_error("Unsupported reduce operation"); } } @@ -180,15 +157,15 @@ void TorchCommXCCL::checkWorkQueue(bool isMainThread) { TorchWorkXCCL::WorkStatus status = workq_.garbageCollect(isMainThread); switch (status) { - case TorchWorkXCCL::WorkStatus::TIMEDOUT: - comm_state_ = CommState::TIMEOUT; - break; - case TorchWorkXCCL::WorkStatus::ERROR: - comm_state_ = CommState::ERROR; - break; - default: - // For COMPLETED, NOT_STARTED, and INPROGRESS, no state change needed - break; + case TorchWorkXCCL::WorkStatus::TIMEDOUT: + comm_state_ = CommState::TIMEOUT; + break; + case TorchWorkXCCL::WorkStatus::ERROR: + comm_state_ = CommState::ERROR; + break; + default: + // For COMPLETED, NOT_STARTED, and INPROGRESS, no state change needed + break; } } @@ -201,8 +178,8 @@ void TorchCommXCCL::timeoutWatchdog() noexcept { std::unique_lock lock(timeout_mutex_); // Wait for a shorter interval to check work objects periodically // Wake up either after 1 second or immediately if shutdown is requested - timeout_cv_.wait_for( - lock, std::chrono::seconds(1), [this]() { return shutdown_.load(); }); + timeout_cv_.wait_for(lock, std::chrono::seconds(1), + [this]() { return shutdown_.load(); }); // If we're shutting down, exit the loop if (shutdown_) { @@ -238,12 +215,6 @@ void TorchCommXCCL::checkInitialized() const { } void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { - // Nothing to check in graph capture mode - // XPU doesn't support graph capture yet - // if (getGraphCaptureMode()) { - // return; - // } - // First, check work queue status checkWorkQueue(true); @@ -270,133 +241,35 @@ void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { } } -// bool TorchCommXCCL::getGraphCaptureMode() { -// xpuStream_t current_stream = -// xpu_api_->getCurrentXPUStream(device_.index()); -// xpuStreamCaptureStatus capture_status; - -// xpuError_t err = -// xpu_api_->streamIsCapturing(current_stream, &capture_status); -// if (err == xpuSuccess) { -// return capture_status == xpuStreamCaptureStatusActive; -// } - -// throw std::runtime_error( -// "Failed to check XPU stream capture status: " + -// std::string(xpu_api_->getErrorString(err))); -// } - -std::shared_ptr TorchCommXCCL::createWork( - xpuStream_t stream, - std::chrono::milliseconds 
timeout, - const std::vector& inputTensors) { +std::shared_ptr +TorchCommXCCL::createWork(xpuStream_t stream, std::chrono::milliseconds timeout, + const std::vector &inputTensors) { // Only create the work object without enqueuing it - auto work = std::make_shared( - shared_from_this(), stream, timeout, inputTensors, tracing_); + auto work = std::make_shared(shared_from_this(), stream, + timeout, inputTensors, tracing_); return work; } -void TorchCommXCCL::enqueueWork( - std::shared_ptr work, - xpuStream_t stream) { - // In graph capture mode, keep a reference to the work object to prevent - // premature destruction until the graph gets destroyed, organized per graph - // if (getGraphCaptureMode()) { - // xpuStreamCaptureStatus capture_status; - // unsigned long long graph_id; - // xpuGraph_t graph; - - // xpuError_t err = xpu_api_->streamGetCaptureInfo_v2( - // stream, &capture_status, &graph_id, &graph, nullptr, nullptr); - // if (err != xpuSuccess) { - // throw std::runtime_error( - // "Failed to get XPU stream capture info: " + - // std::string(xpu_api_->getErrorString(err))); - // } else if (capture_status == xpuStreamCaptureStatusActive) { - // std::lock_guard lock(graph_capture_work_mutex_); - - // // Check if this is the first work object for this graph - // bool is_first_work = graph_capture_work_refs_[graph_id].empty(); - - // // Add work reference to the per-graph container - // graph_capture_work_refs_[graph_id].push_back(work); - - // // If this is the first work object for this graph, set up automatic - // // cleanup - // if (is_first_work) { - // // Create cleanup data that will be passed to the callback - // auto* cleanup_data = new GraphCleanupData(this, graph_id); - - // // Create a XPU user object with our cleanup callback - // xpuUserObject_t user_object; - // err = xpu_api_->userObjectCreate( - // &user_object, - // cleanup_data, - // graphCleanupCallback, - // 1, // initial reference count - // xpuUserObjectNoDestructorSync); - // if (err != xpuSuccess) { - // // If we failed to create the user object, clean up manually - // delete cleanup_data; - // throw std::runtime_error( - // "Failed to create user object: " + - // std::string(xpu_api_->getErrorString(err))); - // } else { - // // Retain the user object in the graph so it gets cleaned up when the - // // graph is destroyed - // err = xpu_api_->graphRetainUserObject( - // graph, - // user_object, - // 1, // reference count - // xpuGraphUserObjectMove); - // if (err != xpuSuccess) { - // // If we failed to retain the user object, clean up manually - // delete cleanup_data; - // throw std::runtime_error( - // "Failed to retain user object: " + - // std::string(xpu_api_->getErrorString(err))); - // } - // } - // } - // } - // } else { +void TorchCommXCCL::enqueueWork(std::shared_ptr work, + xpuStream_t stream) { // Add work to stream's queue after events have been recorded workq_.enqueueWork(std::move(work), stream); - // } } -// // Static callback function for XPU user object cleanup -// void XPURT_CB TorchCommXCCL::graphCleanupCallback(void* userData) { -// auto* cleanup_data = static_cast(userData); -// if (cleanup_data == nullptr || cleanup_data->comm == nullptr) { -// throw std::runtime_error("Invalid cleanup data"); -// } - -// // Clear the work references for this graph -// std::lock_guard lock( -// cleanup_data->comm->graph_capture_work_mutex_); -// cleanup_data->comm->graph_capture_work_refs_.erase(cleanup_data->graph_id); - -// // Clean up the cleanup data itself -// delete cleanup_data; -// } - xpuStream_t 
TorchCommXCCL::getOperationStream(bool async_op) { if (async_op) { // Get current PyTorch XPU stream for this device - xpuStream_t current_stream = - xpu_api_->getCurrentXPUStream(device_.index()); + xpuStream_t current_stream = xpu_api_->getCurrentXPUStream(device_.index()); // Record event on current stream and wait for it on internal stream - XPU_CHECK( - xpu_api_, - xpu_api_->eventRecord(dependency_event_.value(), current_stream), - "Failed to record dependency event"); + XPU_CHECK(xpu_api_, + xpu_api_->eventRecord(dependency_event_.value(), current_stream), + "Failed to record dependency event"); - XPU_CHECK( - xpu_api_, - xpu_api_->streamWaitEvent(internal_stream_.value(), dependency_event_.value(), 0), - "Failed to make internal stream wait for dependency event"); + XPU_CHECK(xpu_api_, + xpu_api_->streamWaitEvent(internal_stream_.value(), + dependency_event_.value(), 0), + "Failed to make internal stream wait for dependency event"); return internal_stream_.value(); } else { @@ -405,7 +278,7 @@ xpuStream_t TorchCommXCCL::getOperationStream(bool async_op) { } } -void TorchCommXCCL::ensureTensorContiguous(const at::Tensor& tensor) { +void TorchCommXCCL::ensureTensorContiguous(const at::Tensor &tensor) { if (!tensor.is_contiguous()) { throw std::runtime_error("Tensor must be contiguous for XCCL operations"); } @@ -423,22 +296,20 @@ xpuEvent_t TorchCommXCCL::getEvent() { // Create new event if pool is empty xpuEvent_t event; - XPU_CHECK( - xpu_api_, - xpu_api_->eventCreateWithFlags(event, /*flags=*/0), - "Failed to create event"); + XPU_CHECK(xpu_api_, xpu_api_->eventCreateWithFlags(event, /*flags=*/0), + "Failed to create event"); return event; } -void TorchCommXCCL::returnEvent(xpuEvent_t&& event) { +void TorchCommXCCL::returnEvent(xpuEvent_t &&event) { std::lock_guard lock(event_pool_mutex_); if (event_pool_.size() < max_event_pool_size_) { event_pool_.push(std::move(event)); } else { // Pool is full, destroy the event - XPU_CHECK( - xpu_api_, xpu_api_->eventDestroy(event), "Failed to destroy event"); + XPU_CHECK(xpu_api_, xpu_api_->eventDestroy(event), + "Failed to destroy event"); } } diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.cpp b/comms/torchcomms/xccl/TorchWorkXCCL.cpp index 7fd99c80..1d40674b 100644 --- a/comms/torchcomms/xccl/TorchWorkXCCL.cpp +++ b/comms/torchcomms/xccl/TorchWorkXCCL.cpp @@ -1,24 +1,18 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
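getOperationStream() above is the ordering glue for async ops: record a marker event on the caller's current stream, then make the internal stream wait on it before the collective is enqueued there. A self-contained sketch with stand-in types; the real code routes through XpuApi with xpuStream_t and xpuEvent_t.

    struct Stream { int id = 0; };  // stand-in for xpuStream_t
    struct Event { int id = 0; };   // stand-in for xpuEvent_t
    struct Api {
      void eventRecord(Event&, Stream&) {}
      void streamWaitEvent(Stream&, Event&, unsigned /*flags*/) {}
    };

    // Async path: fence the caller's stream, aim the collective at ours.
    Stream orderAfterCurrent(Api& api, Stream current, Stream internal, Event& dep) {
      api.eventRecord(dep, current);          // marks all work queued so far
      api.streamWaitEvent(internal, dep, 0);  // internal stream waits on the mark
      return internal;                        // collective is enqueued here
    }

The synchronous path skips all of this and simply launches on the caller's current stream.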
- #include "comms/torchcomms/xccl/TorchWorkXCCL.hpp" -#include #include "comms/torchcomms/TorchCommLogging.hpp" #include "comms/torchcomms/xccl/TorchCommXCCL.hpp" +#include namespace torch { namespace comms { -TorchWorkXCCL::TorchWorkXCCL( - std::shared_ptr comm, - xpuStream_t stream, - std::chrono::milliseconds timeout_ms, - const std::vector& inputTensors, - std::shared_ptr tracing) - : inputTensors_(inputTensors), - comm_(std::move(comm)), - stream_(stream), - timeout_ms_(timeout_ms), - state_(WorkStatus::NOT_STARTED), +TorchWorkXCCL::TorchWorkXCCL(std::shared_ptr comm, + xpuStream_t stream, + std::chrono::milliseconds timeout_ms, + const std::vector &inputTensors, + std::shared_ptr tracing) + : inputTensors_(inputTensors), comm_(std::move(comm)), stream_(stream), + timeout_ms_(timeout_ms), state_(WorkStatus::NOT_STARTED), tracing_(std::move(tracing)) { // If not in graph capture mode, create the events for start and end // recording @@ -38,22 +32,18 @@ TorchWorkXCCL::~TorchWorkXCCL() { } void TorchWorkXCCL::recordStart() { - XPU_CHECK( - comm_->getXpuApi(), - comm_->getXpuApi()->eventRecord(start_event_, stream_), - "Failed to record start event"); + XPU_CHECK(comm_->getXpuApi(), + comm_->getXpuApi()->eventRecord(start_event_, stream_), + "Failed to record start event"); } void TorchWorkXCCL::recordEnd() { - XPU_CHECK( - comm_->getXpuApi(), - comm_->getXpuApi()->eventRecord(end_event_, stream_), - "Failed to record end event"); + XPU_CHECK(comm_->getXpuApi(), + comm_->getXpuApi()->eventRecord(end_event_, stream_), + "Failed to record end event"); } -bool TorchWorkXCCL::isCompleted() { - return state_ == WorkStatus::COMPLETED; -} +bool TorchWorkXCCL::isCompleted() { return state_ == WorkStatus::COMPLETED; } TorchWorkXCCL::WorkStatus TorchWorkXCCL::checkStatus() { // If already marked as completed, return COMPLETED @@ -71,9 +61,8 @@ TorchWorkXCCL::WorkStatus TorchWorkXCCL::checkStatus() { // Start event has completed, store the current time start_completed_time_ = std::chrono::steady_clock::now(); state_ = WorkStatus::INPROGRESS; - } else if ( - start_status != XPU_ERROR_NOT_READY && - start_status != XPU_ERROR_UNSUPPORTED) { + } else if (start_status != XPU_ERROR_NOT_READY && + start_status != XPU_ERROR_UNSUPPORTED) { // Some other error occurred with the start event TC_LOG(ERROR) << "XPU error during start event query: " << comm_->getXpuApi()->getErrorString(start_status) << " (" @@ -133,10 +122,9 @@ void TorchWorkXCCL::wait() { // Add a dependency from the work's stream to the current stream // This makes the current stream wait for the end_event_ recorded on the // work's stream - XPU_CHECK( - comm_->getXpuApi(), - comm_->getXpuApi()->streamWaitEvent(current_stream, end_event_, 0), - "Failed to make stream wait for event"); + XPU_CHECK(comm_->getXpuApi(), + comm_->getXpuApi()->streamWaitEvent(current_stream, end_event_, 0), + "Failed to make stream wait for event"); } } // namespace comms } // namespace torch diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.hpp b/comms/torchcomms/xccl/TorchWorkXCCL.hpp index 0c0e9fbc..c51f04d8 100644 --- a/comms/torchcomms/xccl/TorchWorkXCCL.hpp +++ b/comms/torchcomms/xccl/TorchWorkXCCL.hpp @@ -1,5 +1,3 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
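TorchWorkXCCL::checkStatus() above is a small state machine driven by two event queries; the timeout clock starts only once the start event fires, so time spent queued behind other work is not billed against the collective itself. A simplified model, with booleans standing in for the eventQuery results on start_event_ and end_event_:

    #include <chrono>

    enum class Status { NOT_STARTED, INPROGRESS, COMPLETED, TIMEDOUT };

    struct WorkModel {
      bool startDone = false;  // stand-in for eventQuery(start_event_)
      bool endDone = false;    // stand-in for eventQuery(end_event_)
      std::chrono::steady_clock::time_point startedAt;
      std::chrono::milliseconds timeout{1000};
      Status state = Status::NOT_STARTED;

      Status poll() {
        if (state == Status::NOT_STARTED && startDone) {
          startedAt = std::chrono::steady_clock::now();  // clock starts here
          state = Status::INPROGRESS;
        }
        if (state == Status::INPROGRESS) {
          if (endDone) {
            state = Status::COMPLETED;
          } else if (std::chrono::steady_clock::now() - startedAt > timeout) {
            state = Status::TIMEDOUT;
          }
        }
        return state;
      }
    };

The real code additionally treats XPU_ERROR_NOT_READY and XPU_ERROR_UNSUPPORTED as non-fatal query outcomes and maps any other error to the ERROR state.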
- #pragma once #include @@ -11,7 +9,6 @@ #include #include -// #include // @manual=third-party//xpu:xpu-lazy #include "comms/torchcomms/TorchCommTracing.hpp" #include "comms/torchcomms/TorchWork.hpp" #include "comms/torchcomms/device/XpuApi.hpp" @@ -29,48 +26,44 @@ class TorchCommXCCLTest; } class TorchWorkXCCL : public TorchWork { - public: +public: // Status of a work object enum class WorkStatus { NOT_STARTED, // Work has not started yet - INPROGRESS, // Work is still in progress, - COMPLETED, // Work has completed successfully - TIMEDOUT, // Work has timed out - ERROR // Work has encountered an error + INPROGRESS, // Work is still in progress, + COMPLETED, // Work has completed successfully + TIMEDOUT, // Work has timed out + ERROR // Work has encountered an error }; - TorchWorkXCCL( - std::shared_ptr comm, - xpuStream_t stream, - std::chrono::milliseconds timeout_ms, - const std::vector& inputTensors, - std::shared_ptr tracing); + TorchWorkXCCL(std::shared_ptr comm, xpuStream_t stream, + std::chrono::milliseconds timeout_ms, + const std::vector &inputTensors, + std::shared_ptr tracing); ~TorchWorkXCCL() override; // Delete copy and move operations - TorchWorkXCCL(const TorchWorkXCCL&) = delete; - TorchWorkXCCL(TorchWorkXCCL&&) = delete; - TorchWorkXCCL& operator=(const TorchWorkXCCL&) = delete; - TorchWorkXCCL& operator=(TorchWorkXCCL&&) = delete; + TorchWorkXCCL(const TorchWorkXCCL &) = delete; + TorchWorkXCCL(TorchWorkXCCL &&) = delete; + TorchWorkXCCL &operator=(const TorchWorkXCCL &) = delete; + TorchWorkXCCL &operator=(TorchWorkXCCL &&) = delete; // Override virtual functions from TorchWork bool isCompleted() override; void wait() override; - protected: +protected: void recordStart(); void recordEnd(); friend class TorchCommXCCL; friend class TorchWorkXCCLQueue; - private: +private: // Check the status of the work object WorkStatus checkStatus(); - std::chrono::milliseconds getTimeout() const { - return timeout_ms_; - } + std::chrono::milliseconds getTimeout() const { return timeout_ms_; } std::vector inputTensors_; std::shared_ptr comm_; @@ -88,7 +81,7 @@ class TorchWorkXCCL : public TorchWork { }; class TorchWorkXCCLQueue { - public: +public: TorchWorkXCCLQueue() = default; ~TorchWorkXCCLQueue() = default; @@ -97,7 +90,7 @@ class TorchWorkXCCLQueue { TorchWorkXCCL::WorkStatus finalize(); void enqueueWork(std::shared_ptr work, xpuStream_t stream); - private: +private: std::unordered_map>> stream_work_queues_; std::vector> completed_work_queue_; diff --git a/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp index d1d0267c..20d94d22 100644 --- a/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp +++ b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp @@ -1,12 +1,10 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
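garbageCollect() below scans each per-stream FIFO and stops at the first entry that is not finished, which is sound because operations enqueued on one stream complete in order: nothing behind an unfinished head can be done yet. The core loop, sketched generically:

    #include <memory>
    #include <queue>

    // FIFO invariant: same-stream work completes in submission order.
    template <typename Work>
    void collectCompleted(std::queue<std::shared_ptr<Work>>& q) {
      while (!q.empty() && q.front()->isCompleted()) {
        q.pop();  // head is done; drop it (the real code parks it so the
                  // main thread can recycle its events later)
      }
    }

Timed-out or errored heads short-circuit the real scan so the communicator can be aborted promptly.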
- #include "comms/torchcomms/xccl/TorchWorkXCCL.hpp" namespace torch { namespace comms { -TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::garbageCollect( - bool isMainThread) { +TorchWorkXCCL::WorkStatus +TorchWorkXCCLQueue::garbageCollect(bool isMainThread) { std::lock_guard lock(work_queues_mutex_); TorchWorkXCCL::WorkStatus last_status = TorchWorkXCCL::WorkStatus::COMPLETED; @@ -16,7 +14,7 @@ TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::garbageCollect( // Use an iterator to safely remove empty queues while iterating auto it = stream_work_queues_.begin(); while (it != stream_work_queues_.end()) { - auto& work_queue = it->second; + auto &work_queue = it->second; while (!work_queue.empty()) { // Get the first work object in the queue @@ -31,9 +29,8 @@ TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::garbageCollect( work_queue.pop(); completed_work_queue_.push_back(work); // Continue to the next element in the queue - } else if ( - status == TorchWorkXCCL::WorkStatus::TIMEDOUT || - status == TorchWorkXCCL::WorkStatus::ERROR) { + } else if (status == TorchWorkXCCL::WorkStatus::TIMEDOUT || + status == TorchWorkXCCL::WorkStatus::ERROR) { // Return the error status immediately return status; } else { @@ -88,9 +85,8 @@ TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::finalize() { return status; } -void TorchWorkXCCLQueue::enqueueWork( - std::shared_ptr work, - xpuStream_t stream) { +void TorchWorkXCCLQueue::enqueueWork(std::shared_ptr work, + xpuStream_t stream) { // Add work to stream's queue after events have been recorded std::lock_guard lock(work_queues_mutex_); stream_work_queues_[stream].push(work); diff --git a/comms/torchcomms/xccl/XcclApi.cpp b/comms/torchcomms/xccl/XcclApi.cpp index 96d59bbc..ec0c1f7b 100644 --- a/comms/torchcomms/xccl/XcclApi.cpp +++ b/comms/torchcomms/xccl/XcclApi.cpp @@ -1,13 +1,10 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. 
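The point of the XcclApi/DefaultXcclApi split below, and of the setXcclApi()/setXpuApi() hooks on the comm class, is a test seam: every library call routes through a virtual interface that a fake can replace. A compact sketch of the seam; the interface shown is illustrative and far narrower than the real one.

    #include <cstddef>

    struct CclApi {
      virtual ~CclApi() = default;
      virtual int allReduce(void* buf, std::size_t count) = 0;
    };

    struct RealApi : CclApi {
      int allReduce(void* /*buf*/, std::size_t /*count*/) override {
        return 0;  // would forward to the underlying library call
      }
    };

    struct FakeApi : CclApi {  // injected via a setter in tests
      int calls = 0;
      int allReduce(void*, std::size_t) override {
        ++calls;  // record the call instead of touching the device
        return 0;
      }
    };

Tests can then assert on FakeApi::calls (or simulate error codes) without a GPU present, which is why the default implementation stays a thin pass-through.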
- #include "comms/torchcomms/xccl/XcclApi.hpp" #include "comms/torchcomms/TorchCommLogging.hpp" namespace torch { namespace comms { -// DefaultXcclApi implementation -const char* DefaultXcclApi::getErrorString(onecclResult_t result) { +const char *DefaultXcclApi::getErrorString(onecclResult_t result) { return onecclGetErrorString(result); } @@ -15,16 +12,15 @@ onecclResult_t DefaultXcclApi::setDevice(int device) { return onecclSetDevice(device); } -onecclResult_t DefaultXcclApi::getUniqueId(onecclUniqueId* uniqueId) { +onecclResult_t DefaultXcclApi::getUniqueId(onecclUniqueId *uniqueId) { return onecclGetUniqueId(uniqueId); } -onecclResult_t DefaultXcclApi::commInitRankConfig( - onecclComm_t* comm, - int nranks, - onecclUniqueId commId, - int rank, - onecclConfig_t* config) { +onecclResult_t DefaultXcclApi::commInitRankConfig(onecclComm_t *comm, + int nranks, + onecclUniqueId commId, + int rank, + onecclConfig_t *config) { return onecclCommInitRankConfig(comm, nranks, commId, rank, config); } @@ -37,159 +33,121 @@ onecclResult_t DefaultXcclApi::commAbort(onecclComm_t comm) { return onecclNotImplemented; } -onecclResult_t DefaultXcclApi::commGetAsyncError( - onecclComm_t comm, - onecclResult_t* asyncError) { +onecclResult_t DefaultXcclApi::commGetAsyncError(onecclComm_t comm, + onecclResult_t *asyncError) { // return onecclCommGetAsyncError(comm); return onecclNotImplemented; } -onecclResult_t DefaultXcclApi::commSplit( - onecclComm_t comm, - int color, - int key, - onecclComm_t* newcomm, - onecclConfig_t* config) { +onecclResult_t DefaultXcclApi::commSplit(onecclComm_t comm, int color, int key, + onecclComm_t *newcomm, + onecclConfig_t *config) { return onecclCommSplit(comm, color, key, newcomm, config); } -onecclResult_t DefaultXcclApi::commRegister( - onecclComm_t comm, - void* buffer, - size_t size, - void** handle) { +onecclResult_t DefaultXcclApi::commRegister(onecclComm_t comm, void *buffer, + size_t size, void **handle) { // return onecclCommRegister(comm, buffer, size, handle); return onecclNotImplemented; } -onecclResult_t DefaultXcclApi::commDeregister(onecclComm_t comm, void* handle) { +onecclResult_t DefaultXcclApi::commDeregister(onecclComm_t comm, void *handle) { // return onecclCommDeregister(comm, handle); - return onecclNotImplemented; -} - -onecclResult_t DefaultXcclApi::send( - const void* sendbuff, - size_t count, - onecclDataType_t datatype, - int peer, - onecclComm_t comm, - xpuStream_t stream) { - return onecclSend(const_cast(sendbuff), count, datatype, peer, comm, stream); -} - -onecclResult_t DefaultXcclApi::recv( - void* recvbuff, - size_t count, - onecclDataType_t datatype, - int peer, - onecclComm_t comm, - xpuStream_t stream) { - return onecclRecv(recvbuff, count, datatype, peer, comm, stream); + return onecclNotImplemented; } -onecclResult_t DefaultXcclApi::broadcast( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - int root, - onecclComm_t comm, - xpuStream_t stream) { - return onecclBroadcast(const_cast(sendbuff), recvbuff, count, datatype, root, comm, stream); -} - -onecclResult_t DefaultXcclApi::bcast( - void* buff, - size_t count, - onecclDataType_t datatype, - int root, - onecclComm_t comm, - xpuStream_t stream) { - return onecclBroadcast(buff, buff, count, datatype, root, comm, stream); +onecclResult_t DefaultXcclApi::send(const void *sendbuff, size_t count, + onecclDataType_t datatype, int peer, + onecclComm_t comm, xpuStream_t stream) { + return onecclSend(const_cast(sendbuff), count, datatype, peer, comm, + 
stream); +} + +onecclResult_t DefaultXcclApi::recv(void *recvbuff, size_t count, + onecclDataType_t datatype, int peer, + onecclComm_t comm, xpuStream_t stream) { + return onecclRecv(recvbuff, count, datatype, peer, comm, stream); } -onecclResult_t DefaultXcclApi::allReduce( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclRedOp_t op, - onecclComm_t comm, - xpuStream_t stream) { - return onecclAllReduce(const_cast(sendbuff), recvbuff, count, datatype, op, comm, stream); +onecclResult_t DefaultXcclApi::broadcast(const void *sendbuff, void *recvbuff, + size_t count, + onecclDataType_t datatype, int root, + onecclComm_t comm, + xpuStream_t stream) { + return onecclBroadcast(const_cast(sendbuff), recvbuff, count, + datatype, root, comm, stream); } -onecclResult_t DefaultXcclApi::reduce( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclRedOp_t op, - int root, - onecclComm_t comm, - xpuStream_t stream) { - return onecclReduce( - const_cast(sendbuff), recvbuff, count, datatype, op, root, comm, stream); +onecclResult_t DefaultXcclApi::bcast(void *buff, size_t count, + onecclDataType_t datatype, int root, + onecclComm_t comm, xpuStream_t stream) { + return onecclBroadcast(buff, buff, count, datatype, root, comm, stream); } -onecclResult_t DefaultXcclApi::allGather( - const void* sendbuff, - void* recvbuff, - size_t sendcount, - onecclDataType_t datatype, - onecclComm_t comm, - xpuStream_t stream) { - return onecclAllGather(const_cast(sendbuff), recvbuff, sendcount, datatype, comm, stream); +onecclResult_t DefaultXcclApi::allReduce(const void *sendbuff, void *recvbuff, + size_t count, + onecclDataType_t datatype, + onecclRedOp_t op, onecclComm_t comm, + xpuStream_t stream) { + return onecclAllReduce(const_cast(sendbuff), recvbuff, count, + datatype, op, comm, stream); } -onecclResult_t DefaultXcclApi::reduceScatter( - const void* sendbuff, - void* recvbuff, - size_t recvcount, - onecclDataType_t datatype, - onecclRedOp_t op, - onecclComm_t comm, - xpuStream_t stream) { - return onecclReduceScatter( - const_cast(sendbuff), recvbuff, recvcount, datatype, op, comm, stream); +onecclResult_t DefaultXcclApi::reduce(const void *sendbuff, void *recvbuff, + size_t count, onecclDataType_t datatype, + onecclRedOp_t op, int root, + onecclComm_t comm, xpuStream_t stream) { + return onecclReduce(const_cast(sendbuff), recvbuff, count, datatype, + op, root, comm, stream); } -onecclResult_t DefaultXcclApi::allToAll( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclComm_t comm, - xpuStream_t stream) { - return onecclAllToAll(const_cast(sendbuff), recvbuff, count, datatype, comm, stream); +onecclResult_t DefaultXcclApi::allGather(const void *sendbuff, void *recvbuff, + size_t sendcount, + onecclDataType_t datatype, + onecclComm_t comm, + xpuStream_t stream) { + return onecclAllGather(const_cast(sendbuff), recvbuff, sendcount, + datatype, comm, stream); } -onecclResult_t DefaultXcclApi::groupStart() { - return onecclGroupStart(); +onecclResult_t DefaultXcclApi::reduceScatter(const void *sendbuff, + void *recvbuff, size_t recvcount, + onecclDataType_t datatype, + onecclRedOp_t op, + onecclComm_t comm, + xpuStream_t stream) { + return onecclReduceScatter(const_cast(sendbuff), recvbuff, recvcount, + datatype, op, comm, stream); } -onecclResult_t DefaultXcclApi::groupEnd() { - return onecclGroupEnd(); +onecclResult_t DefaultXcclApi::allToAll(const void *sendbuff, void *recvbuff, + size_t 
count, onecclDataType_t datatype, + onecclComm_t comm, xpuStream_t stream) { + return onecclAllToAll(const_cast(sendbuff), recvbuff, count, datatype, + comm, stream); } -onecclResult_t DefaultXcclApi::commUserRank(const onecclComm_t comm, int* myRank) { +onecclResult_t DefaultXcclApi::groupStart() { return onecclGroupStart(); } + +onecclResult_t DefaultXcclApi::groupEnd() { return onecclGroupEnd(); } + +onecclResult_t DefaultXcclApi::commUserRank(const onecclComm_t comm, + int *myRank) { return onecclCommUserRank(comm, myRank); } -onecclResult_t DefaultXcclApi::commCount(const onecclComm_t comm, int* count) { +onecclResult_t DefaultXcclApi::commCount(const onecclComm_t comm, int *count) { return onecclCommCount(comm, count); } onecclResult_t DefaultXcclApi::redOpCreatePreMulSum( - onecclRedOp_t* op, - void* scalar, - onecclDataType_t datatype, - onecclScalarResidence_t residence, - onecclComm_t comm) { + onecclRedOp_t *op, void *scalar, onecclDataType_t datatype, + onecclScalarResidence_t residence, onecclComm_t comm) { return onecclRedOpCreatePreMulSum(op, scalar, datatype, residence, comm); } -onecclResult_t DefaultXcclApi::redOpDestroy(onecclRedOp_t op, onecclComm_t comm) { +onecclResult_t DefaultXcclApi::redOpDestroy(onecclRedOp_t op, + onecclComm_t comm) { return onecclRedOpDestroy(op, comm); } diff --git a/comms/torchcomms/xccl/XcclApi.hpp b/comms/torchcomms/xccl/XcclApi.hpp index a6e0b577..9e17bb26 100644 --- a/comms/torchcomms/xccl/XcclApi.hpp +++ b/comms/torchcomms/xccl/XcclApi.hpp @@ -9,138 +9,90 @@ namespace torch { namespace comms { class XcclApi { - public: +public: virtual ~XcclApi() = default; - virtual const char* getErrorString(onecclResult_t result) = 0; + virtual const char *getErrorString(onecclResult_t result) = 0; - // Device management virtual onecclResult_t setDevice(int device) = 0; - // Unique ID generation - virtual onecclResult_t getUniqueId(onecclUniqueId* uniqueId) = 0; + virtual onecclResult_t getUniqueId(onecclUniqueId *uniqueId) = 0; - // Communicator management - virtual onecclResult_t commInitRankConfig( - onecclComm_t* comm, - int nranks, - onecclUniqueId commId, - int rank, - onecclConfig_t* config) = 0; + virtual onecclResult_t commInitRankConfig(onecclComm_t *comm, int nranks, + onecclUniqueId commId, int rank, + onecclConfig_t *config) = 0; virtual onecclResult_t commDestroy(onecclComm_t comm) = 0; virtual onecclResult_t commAbort(onecclComm_t comm) = 0; - virtual onecclResult_t commGetAsyncError( - onecclComm_t comm, - onecclResult_t* asyncError) = 0; + virtual onecclResult_t commGetAsyncError(onecclComm_t comm, + onecclResult_t *asyncError) = 0; - virtual onecclResult_t commSplit( - onecclComm_t comm, - int color, - int key, - onecclComm_t* newcomm, - onecclConfig_t* config) = 0; + virtual onecclResult_t commSplit(onecclComm_t comm, int color, int key, + onecclComm_t *newcomm, + onecclConfig_t *config) = 0; - // Memory registration - virtual onecclResult_t - commRegister(onecclComm_t comm, void* buffer, size_t size, void** handle) = 0; + virtual onecclResult_t commRegister(onecclComm_t comm, void *buffer, + size_t size, void **handle) = 0; - virtual onecclResult_t commDeregister(onecclComm_t comm, void* handle) = 0; + virtual onecclResult_t commDeregister(onecclComm_t comm, void *handle) = 0; // Point-to-point operations - virtual onecclResult_t send( - const void* sendbuff, - size_t count, - onecclDataType_t datatype, - int peer, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t recv( - void* recvbuff, - size_t count, - 
onecclDataType_t datatype, - int peer, - onecclComm_t comm, - xpuStream_t stream) = 0; + virtual onecclResult_t send(const void *sendbuff, size_t count, + onecclDataType_t datatype, int peer, + onecclComm_t comm, xpuStream_t stream) = 0; + + virtual onecclResult_t recv(void *recvbuff, size_t count, + onecclDataType_t datatype, int peer, + onecclComm_t comm, xpuStream_t stream) = 0; // Collective operations - virtual onecclResult_t broadcast( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - int root, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t bcast( - void* buff, - size_t count, - onecclDataType_t datatype, - int root, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t allReduce( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclRedOp_t op, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t reduce( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclRedOp_t op, - int root, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t allGather( - const void* sendbuff, - void* recvbuff, - size_t sendcount, - onecclDataType_t datatype, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t reduceScatter( - const void* sendbuff, - void* recvbuff, - size_t recvcount, - onecclDataType_t datatype, - onecclRedOp_t op, - onecclComm_t comm, - xpuStream_t stream) = 0; - - virtual onecclResult_t allToAll( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclComm_t comm, - xpuStream_t stream) = 0; + virtual onecclResult_t broadcast(const void *sendbuff, void *recvbuff, + size_t count, onecclDataType_t datatype, + int root, onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t bcast(void *buff, size_t count, + onecclDataType_t datatype, int root, + onecclComm_t comm, xpuStream_t stream) = 0; + + virtual onecclResult_t allReduce(const void *sendbuff, void *recvbuff, + size_t count, onecclDataType_t datatype, + onecclRedOp_t op, onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t reduce(const void *sendbuff, void *recvbuff, + size_t count, onecclDataType_t datatype, + onecclRedOp_t op, int root, onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t allGather(const void *sendbuff, void *recvbuff, + size_t sendcount, onecclDataType_t datatype, + onecclComm_t comm, xpuStream_t stream) = 0; + + virtual onecclResult_t reduceScatter(const void *sendbuff, void *recvbuff, + size_t recvcount, + onecclDataType_t datatype, + onecclRedOp_t op, onecclComm_t comm, + xpuStream_t stream) = 0; + + virtual onecclResult_t allToAll(const void *sendbuff, void *recvbuff, + size_t count, onecclDataType_t datatype, + onecclComm_t comm, xpuStream_t stream) = 0; // Group operations virtual onecclResult_t groupStart() = 0; virtual onecclResult_t groupEnd() = 0; - virtual onecclResult_t commUserRank(const onecclComm_t comm, int* userRank) = 0; - virtual onecclResult_t commCount(const onecclComm_t comm, int* count) = 0; + virtual onecclResult_t commUserRank(const onecclComm_t comm, + int *userRank) = 0; + virtual onecclResult_t commCount(const onecclComm_t comm, int *count) = 0; - virtual onecclResult_t redOpCreatePreMulSum( - onecclRedOp_t* op, - void* scalar, - onecclDataType_t datatype, - onecclScalarResidence_t residence, - onecclComm_t comm) = 0; + virtual onecclResult_t 
redOpCreatePreMulSum(onecclRedOp_t *op, void *scalar, + onecclDataType_t datatype, + onecclScalarResidence_t residence, + onecclComm_t comm) = 0; virtual onecclResult_t redOpDestroy(onecclRedOp_t op, onecclComm_t comm) = 0; }; @@ -148,140 +100,88 @@ class XcclApi { * Default implementation that calls the underlying XCCL APIs directly. */ class DefaultXcclApi : public XcclApi { - public: +public: ~DefaultXcclApi() override = default; // Error handling - const char* getErrorString(onecclResult_t result) override; + const char *getErrorString(onecclResult_t result) override; // Device management onecclResult_t setDevice(int device) override; // Unique ID generation - onecclResult_t getUniqueId(onecclUniqueId* uniqueId) override; + onecclResult_t getUniqueId(onecclUniqueId *uniqueId) override; // Communicator management - onecclResult_t commInitRankConfig( - onecclComm_t* comm, - int nranks, - onecclUniqueId commId, - int rank, - onecclConfig_t* config) override; + onecclResult_t commInitRankConfig(onecclComm_t *comm, int nranks, + onecclUniqueId commId, int rank, + onecclConfig_t *config) override; onecclResult_t commDestroy(onecclComm_t comm) override; onecclResult_t commAbort(onecclComm_t comm) override; - onecclResult_t commGetAsyncError(onecclComm_t comm, onecclResult_t* asyncError) - override; + onecclResult_t commGetAsyncError(onecclComm_t comm, + onecclResult_t *asyncError) override; - onecclResult_t commSplit( - onecclComm_t comm, - int color, - int key, - onecclComm_t* newcomm, - onecclConfig_t* config) override; + onecclResult_t commSplit(onecclComm_t comm, int color, int key, + onecclComm_t *newcomm, + onecclConfig_t *config) override; - onecclResult_t commRegister( - onecclComm_t comm, - void* buffer, - size_t size, - void** handle) override; + onecclResult_t commRegister(onecclComm_t comm, void *buffer, size_t size, + void **handle) override; - onecclResult_t commDeregister(onecclComm_t comm, void* handle) override; + onecclResult_t commDeregister(onecclComm_t comm, void *handle) override; // Point-to-point operations - onecclResult_t send( - const void* sendbuff, - size_t count, - onecclDataType_t datatype, - int peer, - onecclComm_t comm, - xpuStream_t stream) override; - - onecclResult_t recv( - void* recvbuff, - size_t count, - onecclDataType_t datatype, - int peer, - onecclComm_t comm, - xpuStream_t stream) override; + onecclResult_t send(const void *sendbuff, size_t count, + onecclDataType_t datatype, int peer, onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t recv(void *recvbuff, size_t count, onecclDataType_t datatype, + int peer, onecclComm_t comm, xpuStream_t stream) override; // Collective operations - onecclResult_t broadcast( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - int root, - onecclComm_t comm, - xpuStream_t stream) override; - - onecclResult_t bcast( - void* buff, - size_t count, - onecclDataType_t datatype, - int root, - onecclComm_t comm, - xpuStream_t stream) override; - - onecclResult_t allReduce( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclRedOp_t op, - onecclComm_t comm, - xpuStream_t stream) override; - - onecclResult_t reduce( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclRedOp_t op, - int root, - onecclComm_t comm, - xpuStream_t stream) override; - - onecclResult_t allGather( - const void* sendbuff, - void* recvbuff, - size_t sendcount, - onecclDataType_t datatype, - onecclComm_t comm, - 
xpuStream_t stream) override; - - onecclResult_t reduceScatter( - const void* sendbuff, - void* recvbuff, - size_t recvcount, - onecclDataType_t datatype, - onecclRedOp_t op, - onecclComm_t comm, - xpuStream_t stream) override; - - onecclResult_t allToAll( - const void* sendbuff, - void* recvbuff, - size_t count, - onecclDataType_t datatype, - onecclComm_t comm, - xpuStream_t stream) override; + onecclResult_t broadcast(const void *sendbuff, void *recvbuff, size_t count, + onecclDataType_t datatype, int root, + onecclComm_t comm, xpuStream_t stream) override; + + onecclResult_t bcast(void *buff, size_t count, onecclDataType_t datatype, + int root, onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t allReduce(const void *sendbuff, void *recvbuff, size_t count, + onecclDataType_t datatype, onecclRedOp_t op, + onecclComm_t comm, xpuStream_t stream) override; + + onecclResult_t reduce(const void *sendbuff, void *recvbuff, size_t count, + onecclDataType_t datatype, onecclRedOp_t op, int root, + onecclComm_t comm, xpuStream_t stream) override; + + onecclResult_t allGather(const void *sendbuff, void *recvbuff, + size_t sendcount, onecclDataType_t datatype, + onecclComm_t comm, xpuStream_t stream) override; + + onecclResult_t reduceScatter(const void *sendbuff, void *recvbuff, + size_t recvcount, onecclDataType_t datatype, + onecclRedOp_t op, onecclComm_t comm, + xpuStream_t stream) override; + + onecclResult_t allToAll(const void *sendbuff, void *recvbuff, size_t count, + onecclDataType_t datatype, onecclComm_t comm, + xpuStream_t stream) override; // Group operations onecclResult_t groupStart() override; onecclResult_t groupEnd() override; - onecclResult_t commUserRank(const onecclComm_t comm, int* userRank) override; - onecclResult_t commCount(const onecclComm_t comm, int* count) override; + onecclResult_t commUserRank(const onecclComm_t comm, int *userRank) override; + onecclResult_t commCount(const onecclComm_t comm, int *count) override; - onecclResult_t redOpCreatePreMulSum( - onecclRedOp_t* op, - void* scalar, - onecclDataType_t datatype, - onecclScalarResidence_t residence, - onecclComm_t comm) override; + onecclResult_t redOpCreatePreMulSum(onecclRedOp_t *op, void *scalar, + onecclDataType_t datatype, + onecclScalarResidence_t residence, + onecclComm_t comm) override; onecclResult_t redOpDestroy(onecclRedOp_t op, onecclComm_t comm) override; }; From 44fb740a360de8847d780d5555a64aef97703c99 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 3 Nov 2025 16:19:27 +0800 Subject: [PATCH 03/10] typo --- comms/torchcomms/xccl/TorchCommXCCL.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp index c813fe25..7914727c 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.hpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp @@ -137,7 +137,7 @@ class TorchCommXCCL : public TorchCommBackend, // friend class CachingAllocatorHookImpl; friend class TorchCommWindowXCCL; - // Getter for CUDA API (for friend classes) + // Getter for XPU API (for friend classes) XpuApi *getXpuApi() const { return xpu_api_.get(); } // Getter for XCCL API (for friend classes) @@ -146,7 +146,7 @@ class TorchCommXCCL : public TorchCommBackend, // Method to override the XCCL API implementation for testing void setXcclApi(std::shared_ptr api) { xccl_api_ = std::move(api); } - // Method to override the CUDA API implementation for testing + // Method to override the XPU API implementation for testing void 
setXpuApi(std::shared_ptr<XpuApi> api) { xpu_api_ = std::move(api); }

 const CommOptions &getOptions() const override { return options_; }

@@ -254,7 +254,7 @@ class TorchCommXCCL : public TorchCommBackend,
 std::optional<xpuStream_t> internal_stream_; // Initialized in init()
 std::optional<xpuEvent_t>
 dependency_event_; // Pre-allocated event for stream dependencies
- void *barrier_buffer_{}; // Pre-allocated CUDA buffer for barrier operations
+ void *barrier_buffer_{}; // Pre-allocated XPU buffer for barrier operations
 enum class InitializationState {
 UNINITIALIZED,
 INITIALIZED,

From 7b2fe76b7a893144ad4656aa8a9bc4cc4943080b Mon Sep 17 00:00:00 2001
From: lzhang2
Date: Thu, 6 Nov 2025 17:10:55 +0800
Subject: [PATCH 04/10] revert split support and leave for a separate PR

---
 comms/torchcomms/device/XpuApi.cpp | 9 --
 comms/torchcomms/device/XpuApi.hpp | 3 -
 comms/torchcomms/xccl/TorchCommXCCL.cpp | 109 +------------------
 comms/torchcomms/xccl/TorchCommXCCL.hpp | 2 -
 comms/torchcomms/xccl/TorchCommXCCLUtils.cpp | 6 +-
 5 files changed, 3 insertions(+), 126 deletions(-)

diff --git a/comms/torchcomms/device/XpuApi.cpp b/comms/torchcomms/device/XpuApi.cpp
index 18f80143..1a5dc7c8 100644
--- a/comms/torchcomms/device/XpuApi.cpp
+++ b/comms/torchcomms/device/XpuApi.cpp
@@ -33,10 +33,6 @@ xpu_result_t DefaultXpuApi::getDeviceProperties(xpuDeviceProp* prop, int device)
 // Get memory info
 prop->totalGlobalMem = sycl_device.get_info();

- // Set version info (XPU doesn't have major/minor version like CUDA)
- prop->major = 1;
- prop->minor = 0;
-
 // Get compute capabilities
 auto max_work_group_size = sycl_device.get_info();
 auto max_work_item_sizes = sycl_device.get_info>();
@@ -48,11 +44,6 @@ xpu_result_t DefaultXpuApi::getDeviceProperties(xpuDeviceProp* prop, int device)
 prop->maxThreadsDim[1] = max_work_item_sizes[1];
 prop->maxThreadsDim[2] = max_work_item_sizes[2];

- // Max grid size - not directly available in SYCL, use reasonable defaults
- prop->maxGridSize[0] = 2147483647;
- prop->maxGridSize[1] = 65535;
- prop->maxGridSize[2] = 65535;
-
 return XPU_SUCCESS;
 } catch (const std::exception& e) {
 return XPU_ERROR_INVALID_VALUE;
 }
diff --git a/comms/torchcomms/device/XpuApi.hpp b/comms/torchcomms/device/XpuApi.hpp
index 1cc14e8b..b07d539d 100644
--- a/comms/torchcomms/device/XpuApi.hpp
+++ b/comms/torchcomms/device/XpuApi.hpp
@@ -15,12 +15,9 @@ using xpuEvent_t = ::at::xpu::XPUEvent;
 struct xpuDeviceProp {
 char name[256];
 size_t totalGlobalMem;
- int major;
- int minor;
 int multiProcessorCount;
 int maxThreadsPerBlock;
 int maxThreadsDim[3];
- int maxGridSize[3];
 };

 // Graph-related types (placeholder - unsupported in XPU)
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp
index 5fb91e49..06c8f1b1 100644
--- a/comms/torchcomms/xccl/TorchCommXCCL.cpp
+++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp
@@ -1056,114 +1056,7 @@ TorchCommXCCL::gather(const std::vector &output_tensor_list,
 std::shared_ptr
 TorchCommXCCL::split(const std::vector<int> &ranks, const std::string &name,
 const CommOptions &options) {
- // Validate the ranks list
- checkAndAbortIfTimedOutOrError();
- std::unordered_set<int> rank_seen;
- for (int rank : ranks) {
- if (rank < 0 || rank >= comm_size_) {
- throw std::runtime_error("Invalid rank " + std::to_string(rank) +
- " in ranks. 
Valid ranks are 0 to " + - std::to_string(comm_size_ - 1)); - } - if (rank_seen.find(rank) != rank_seen.end()) { - throw std::runtime_error("Rank " + std::to_string(rank) + - " appears multiple times in ranks"); - } - rank_seen.insert(rank); - } - - // Determine the color for this rank - int color; - int new_rank = -1; // Rank within the new communicator - - if (ranks.empty()) { - // Empty list means exclude all ranks - use XCCL_SPLIT_NOCOLOR -#ifdef XCCL_SPLIT_NOCOLOR - color = XCCL_SPLIT_NOCOLOR; -#else - throw std::runtime_error("XCCL_SPLIT_NOCOLOR is not defined"); -#endif - new_rank = -1; // Will not participate in new communicator - } else { - // Check if current rank is in the non-empty list - auto it = std::find(ranks.begin(), ranks.end(), rank_); - if (it == ranks.end()) { - // Current rank is not in the non-empty list - this is an error - throw std::runtime_error("Current rank " + std::to_string(rank_) + - " is not included in the provided ranks list"); - } - // Set color to the lowest rank in the group and calculate new rank - color = *std::min_element(ranks.begin(), ranks.end()); - new_rank = static_cast(std::distance(ranks.begin(), it)); - } - - // Create a new XCCL communicator - onecclComm_t new_comm; - onecclConfig_t config = ONECCL_CONFIG_INITIALIZER; - // Note: oneCCL does not have a commName field like NCCL - - // Populate XCCL config from user-provided hints - populateXcclConfigFromHints(config, options, name); - - onecclResult_t result = - xccl_api_->commSplit(xccl_comm_, color, new_rank, &new_comm, &config); - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL split failed", result); - } - if (new_rank == -1) { - return nullptr; // Rank is not in the group, return nullptr - } - - auto new_torchcomm = - std::shared_ptr(new TorchCommXCCL(new_comm)); - new_torchcomm->xccl_api_ = xccl_api_; - new_torchcomm->xpu_api_ = xpu_api_; - new_torchcomm->init(device_, name, options); - - return new_torchcomm; -} - -void TorchCommXCCL::register_address( - const TorchCommXCCL::AddressWithLen &addr) { - // We got a register after we got rid of the comm. Is this a fatal error? - if (!xccl_comm_) { - return; - } - - if (memoryRegistrationHandles_.contains(addr.addr)) { - throw std::runtime_error("Memory already registered with XCCL"); - } - void *handle = nullptr; - onecclResult_t result = - xccl_api_->commRegister(xccl_comm_, addr.addr, addr.len, &handle); - if (result != onecclSuccess) { - throw std::runtime_error("Failed to register memory with XCCL: " + - std::string(onecclGetErrorString(result))); - } - memoryRegistrationHandles_.emplace(addr.addr, RegistrationHandle(handle)); -} - -void TorchCommXCCL::deregister_address(const TorchCommXCCL::Address &addr) { - // We got a deregister after we got rid of the comm. Is this a fatal error? - if (!xccl_comm_) { - return; - } - - auto it = memoryRegistrationHandles_.find(addr.addr); - if (it == memoryRegistrationHandles_.end()) { - // it' possible that the memory was registered for a different comm, - // however failed registration for this comm. 
- return;
- }
-
- void *handle = it->second.regHandle;
- onecclResult_t result = xccl_api_->commDeregister(xccl_comm_, handle);
- if (result != onecclSuccess) {
- throw std::runtime_error("Failed to deregister memory with XCCL: " +
- std::string(xccl_api_->getErrorString(result)));
- }
-
- memoryRegistrationHandles_.erase(it);
+ throw std::runtime_error("Split is not supported now in XCCL");
 }

 XCCLException::XCCLException(XcclApi &xccl_api, const std::string &message,
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp
index 7914727c..7acbb71d 100644
--- a/comms/torchcomms/xccl/TorchCommXCCL.hpp
+++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp
@@ -177,8 +177,6 @@ class TorchCommXCCL : public TorchCommBackend,
 std::atomic<CommState> comm_state_{
 CommState::NORMAL}; // State of the communicator

- void register_address(const AddressWithLen &addr);
- void deregister_address(const Address &addr);
 onecclDataType_t getXcclDataType(const at::Tensor &tensor);
 std::shared_ptr<TorchWorkXCCL> createWork(xpuStream_t stream,
 std::chrono::milliseconds timeout,
diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp
index a96901e3..c3b67b72 100644
--- a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp
+++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp
@@ -314,14 +314,12 @@ void TorchCommXCCL::returnEvent(xpuEvent_t &&event) {
 }

 void TorchCommXCCL::attachMemoryHook() {
- // NOTE: CachingAllocatorHook is not implemented for XPU/SYCL yet
- // TODO: Implement XPU caching allocator hook when available
+ // NOTE: Currently, oneCCL doesn't support memory register and deregister
 // CachingAllocatorHook::getInstance().registerComm(this);
 }

 void TorchCommXCCL::detachMemoryHook() {
- // NOTE: CachingAllocatorHook is not implemented for XPU/SYCL yet
- // TODO: Implement XPU caching allocator hook when available
+ // NOTE: Currently, oneCCL doesn't support memory register and deregister
 // CachingAllocatorHook::getInstance().deregisterComm(this);
 }

From 8b02e0ad7d29a72bd066849ea898cc5ca810c952 Mon Sep 17 00:00:00 2001
From: lzhang2
Date: Thu, 6 Nov 2025 17:26:09 +0800
Subject: [PATCH 05/10] move out some APIs that we cannot support now

---
 comms/torchcomms/device/XpuApi.cpp | 21 +------
 comms/torchcomms/xccl/TorchCommXCCL.cpp | 10 ----
 comms/torchcomms/xccl/TorchCommXCCL.hpp | 63 --------------------
 comms/torchcomms/xccl/TorchCommXCCLUtils.cpp | 17 +-----
 comms/torchcomms/xccl/TorchWorkXCCL.hpp | 6 --
 5 files changed, 5 insertions(+), 112 deletions(-)

diff --git a/comms/torchcomms/device/XpuApi.cpp b/comms/torchcomms/device/XpuApi.cpp
index 1a5dc7c8..cdccd4ae 100644
--- a/comms/torchcomms/device/XpuApi.cpp
+++ b/comms/torchcomms/device/XpuApi.cpp
@@ -276,7 +276,7 @@ xpu_result_t DefaultXpuApi::graphRetainUserObject(
 xpuUserObject_t object,
 unsigned int count,
 unsigned int flags) {
- // XPU/SYCL doesn't support graphs
+ // Currently, XPU/SYCL doesn't support graphs
 return XPU_ERROR_UNSUPPORTED;
 }

@@ -287,23 +287,8 @@ xpu_result_t DefaultXpuApi::streamGetCaptureInfo_v2(
 xpuGraph_t* graph_out,
 const xpuGraphNode_t** dependencies_out,
 size_t* numDependencies_out) {
- if (captureStatus_out) {
- *captureStatus_out = xpuStreamCaptureStatusNone;
- }
- if (id_out) {
- *id_out = 0;
- }
- if (graph_out) {
- *graph_out = nullptr;
- }
- if (dependencies_out) {
- *dependencies_out = nullptr;
- }
- if (numDependencies_out) {
- *numDependencies_out = 0;
- }
-
- return XPU_SUCCESS;
+ // Currently, XPU/SYCL doesn't support graphs
+ return XPU_ERROR_UNSUPPORTED;
 }

 // Error Handling
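For reference, a minimal caller-side sketch of how the stub above is meant to be consumed — not part of the patch itself. The helper name isStreamCapturing is hypothetical, the stream-first parameter order is assumed from the declaration in XpuApi.hpp, and because the stub now returns XPU_ERROR_UNSUPPORTED without writing to its output pointers, the status variable must be initialized before the call:

// Hedged sketch, not part of the patch: treat XPU_ERROR_UNSUPPORTED as
// "graph capture unavailable" rather than as a fatal error.
bool isStreamCapturing(torch::comms::XpuApi& api, const xpuStream_t& stream) {
  // Initialize defensively: the stubbed implementation above leaves all
  // output parameters untouched.
  auto status = xpuStreamCaptureStatusNone;
  xpu_result_t rc = api.streamGetCaptureInfo_v2(
      stream, &status, /*id_out=*/nullptr, /*graph_out=*/nullptr,
      /*dependencies_out=*/nullptr, /*numDependencies_out=*/nullptr);
  // Only report capture-in-progress when the query itself succeeded.
  return rc == XPU_SUCCESS && status != xpuStreamCaptureStatusNone;
}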
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp index 06c8f1b1..e7f5a359 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp @@ -31,10 +31,6 @@ TorchCommXCCL::~TorchCommXCCL() { timeout_thread_.join(); } } - - // We need to detach the memory hook in case finalize is not called, - // so that we don't encounter a memory corruption. - detachMemoryHook(); } void TorchCommXCCL::init(at::Device device, const std::string &name, @@ -152,9 +148,6 @@ void TorchCommXCCL::init(at::Device device, const std::string &name, // Start timeout watchdog thread timeout_thread_ = std::thread(&TorchCommXCCL::timeoutWatchdog, this); - - // Register comm with CachingAllocator - attachMemoryHook(); } void TorchCommXCCL::finalize() { @@ -237,15 +230,12 @@ void TorchCommXCCL::finalize() { // Destroy XCCL communicator // TODO: should probably not call this after calling abort. if (xccl_comm_) { - detachMemoryHook(); - // Deregister comm from the CachingAllocator xccl_api_->commDestroy(xccl_comm_); xccl_comm_ = nullptr; } } void TorchCommXCCL::abortXcclComm() { - detachMemoryHook(); if (xccl_comm_) { xccl_api_->commAbort(xccl_comm_); xccl_comm_ = nullptr; diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp index 7acbb71d..487d4377 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.hpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp @@ -133,9 +133,6 @@ class TorchCommXCCL : public TorchCommBackend, // Friend access for TorchCommXCCL friend class TorchWorkXCCL; - // NOTE: CachingAllocatorHook is not implemented for XPU/SYCL yet - // friend class CachingAllocatorHookImpl; - friend class TorchCommWindowXCCL; // Getter for XPU API (for friend classes) XpuApi *getXpuApi() const { return xpu_api_.get(); } @@ -165,15 +162,6 @@ class TorchCommXCCL : public TorchCommBackend, TIMEOUT, }; - struct Address { - void *addr; - }; - - struct AddressWithLen { - void *addr; - size_t len; - }; - std::atomic comm_state_{ CommState::NORMAL}; // State of the communicator @@ -206,24 +194,6 @@ class TorchCommXCCL : public TorchCommBackend, std::shared_ptr xccl_api_; }; - // Struct to hold the registration handle for a buffer - struct RegistrationHandle { - void *regHandle; - - explicit RegistrationHandle(void *regHandle) : regHandle{regHandle} {} - - RegistrationHandle(RegistrationHandle &&other) noexcept - : regHandle{other.regHandle} { - other.regHandle = nullptr; - } - - RegistrationHandle(const RegistrationHandle &) = delete; - RegistrationHandle &operator=(const RegistrationHandle &) = delete; - RegistrationHandle &operator=(RegistrationHandle &&) = delete; - - ~RegistrationHandle() = default; - }; - // Constructor for split communicators explicit TorchCommXCCL(const onecclComm_t xccl_comm); @@ -239,9 +209,6 @@ class TorchCommXCCL : public TorchCommBackend, xpuStream_t getOperationStream(bool async_op); void ensureTensorContiguous(const at::Tensor &tensor); - void attachMemoryHook(); - void detachMemoryHook(); - // Member variables onecclComm_t xccl_comm_{}; at::Device device_; @@ -259,10 +226,6 @@ class TorchCommXCCL : public TorchCommBackend, FINALIZED, } init_state_; - // List of [comm, regHandlesMap] pairs. 
Each regHandlesMap is a map from the - // buffer address to the registeration handle - std::map memoryRegistrationHandles_; - // XCCL API abstraction std::shared_ptr xccl_api_; @@ -285,32 +248,6 @@ class TorchCommXCCL : public TorchCommBackend, std::shared_ptr tracing_; bool high_priority_stream_{false}; std::string name_; - - // Graph capture mode work references - // Keep references to work objects during graph capture to prevent premature - // destruction, organized per graph using capture ID - std::unordered_map>> - graph_capture_work_refs_; - std::mutex graph_capture_work_mutex_; - - // Structure to hold cleanup data for XPU user objects - // NOTE: Graph capture cleanup is currently disabled for XPU/SYCL - // as the required APIs (userObjectCreate, graphRetainUserObject) are not yet - // available - struct GraphCleanupData { - TorchCommXCCL *comm; - unsigned long long graph_id; - - GraphCleanupData(TorchCommXCCL *comm_, unsigned long long id) - : comm(comm_), graph_id(id) {} - }; - - // Static callback function for XPU user object cleanup - // NOTE: Currently disabled - XPU/SYCL does not have equivalent callback - // mechanism static void graphCleanupCallback(void* userData); - - friend class TorchWorkXCCLQueueCommTest; }; } // namespace comms diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp index c3b67b72..d338ae2c 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp @@ -1,6 +1,4 @@ #include "comms/torchcomms/xccl/TorchCommXCCL.hpp" -// #include "comms/torchcomms/xccl/TorchCommXCCLCCA.hpp" - #include "comms/torchcomms/TorchCommLogging.hpp" #include #include @@ -219,7 +217,7 @@ void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { checkWorkQueue(true); if (comm_state_ == CommState::TIMEOUT) { - abortXcclComm(); +// abortXcclComm(); // cannot abort oneCCL communicator if (options_.abort_process_on_timeout_or_error) { TC_LOG(ERROR) << "Aborting process due to timeout"; abort(); @@ -230,7 +228,7 @@ void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { onecclResult_t asyncErr; xccl_api_->commGetAsyncError(xccl_comm_, &asyncErr); XCCLException xcclException(*xccl_api_, "XCCL Async Error", asyncErr); - abortXcclComm(); +// abortXcclComm(); // cannot abort oneCCL communicator if (options_.abort_process_on_timeout_or_error) { TC_LOG(ERROR) << "Aborting process due to error: " << xcclException.what(); @@ -312,16 +310,5 @@ void TorchCommXCCL::returnEvent(xpuEvent_t &&event) { "Failed to destroy event"); } } - -void TorchCommXCCL::attachMemoryHook() { - // NOTE: Currently, oneCCL doesn't support memory register and deregister - // CachingAllocatorHook::getInstance().registerComm(this); -} - -void TorchCommXCCL::detachMemoryHook() { - // NOTE: Currently, oneCCL doesn't support memory register and deregister - // CachingAllocatorHook::getInstance().deregisterComm(this); -} - } // namespace comms } // namespace torch diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.hpp b/comms/torchcomms/xccl/TorchWorkXCCL.hpp index c51f04d8..ed60c604 100644 --- a/comms/torchcomms/xccl/TorchWorkXCCL.hpp +++ b/comms/torchcomms/xccl/TorchWorkXCCL.hpp @@ -18,12 +18,6 @@ namespace comms { // Forward declaration class TorchCommXCCL; -class TorchCommWindowXCCL; - -// Forward declaration for test class -namespace test { -class TorchCommXCCLTest; -} class TorchWorkXCCL : public TorchWork { public: From 96bfb6bfb019672d635ee301ff1e279fd64c330d Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 10 Nov 2025 
09:35:50 +0800 Subject: [PATCH 06/10] fix ptr --- comms/torchcomms/xccl/TorchCommXCCL.cpp | 53 +++++++++------ comms/torchcomms/xccl/TorchCommXCCL.hpp | 64 +++++++++++-------- .../xccl/TorchCommXCCLBootstrap.hpp | 3 +- comms/torchcomms/xccl/TorchCommXCCLUtils.cpp | 14 ++-- comms/torchcomms/xccl/TorchWorkXCCL.hpp | 8 +-- comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp | 2 +- 6 files changed, 84 insertions(+), 60 deletions(-) diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp index e7f5a359..75d3a9fe 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp @@ -283,7 +283,7 @@ getOperationTimeout(std::chrono::milliseconds timeout, } // Point-to-Point Operations -std::shared_ptr TorchCommXCCL::send(const at::Tensor &tensor, +c10::intrusive_ptr TorchCommXCCL::send(const at::Tensor &tensor, int dst, bool async_op, const SendOptions &options) { checkInitialized(); @@ -313,7 +313,7 @@ std::shared_ptr TorchCommXCCL::send(const at::Tensor &tensor, return work; } -std::shared_ptr TorchCommXCCL::recv(at::Tensor &tensor, int src, +c10::intrusive_ptr TorchCommXCCL::recv(at::Tensor &tensor, int src, bool async_op, const RecvOptions &options) { checkInitialized(); @@ -344,7 +344,7 @@ std::shared_ptr TorchCommXCCL::recv(at::Tensor &tensor, int src, } // Batch P2P Operations -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::batch_op_issue(const std::vector &ops, bool async_op, const BatchP2POptions &options) { checkInitialized(); @@ -426,7 +426,7 @@ TorchCommXCCL::batch_op_issue(const std::vector &ops, } // Collective Operations -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::broadcast(at::Tensor &tensor, int root, bool async_op, const BroadcastOptions &options) { checkInitialized(); @@ -457,8 +457,8 @@ TorchCommXCCL::broadcast(at::Tensor &tensor, int root, bool async_op, return work; } -std::shared_ptr -TorchCommXCCL::all_reduce(at::Tensor &tensor, ReduceOp op, bool async_op, +c10::intrusive_ptr +TorchCommXCCL::all_reduce(at::Tensor &tensor, const ReduceOp &op, bool async_op, const AllReduceOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); @@ -490,8 +490,8 @@ TorchCommXCCL::all_reduce(at::Tensor &tensor, ReduceOp op, bool async_op, return work; } -std::shared_ptr TorchCommXCCL::reduce(const at::Tensor &tensor, - int root, ReduceOp op, +c10::intrusive_ptr TorchCommXCCL::reduce(const at::Tensor &tensor, + int root, const ReduceOp &op, bool async_op, const ReduceOptions &options) { checkInitialized(); @@ -528,7 +528,7 @@ std::shared_ptr TorchCommXCCL::reduce(const at::Tensor &tensor, return work; } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::all_gather(const std::vector &tensor_list, const at::Tensor &tensor, bool async_op, const AllGatherOptions &options) { @@ -575,7 +575,14 @@ TorchCommXCCL::all_gather(const std::vector &tensor_list, return work; } -std::shared_ptr +c10::intrusive_ptr +TorchCommXCCL::all_gather_v(const std::vector &tensor_list, + const at::Tensor &tensor, bool async_op, + const AllGatherOptions &options) { + throw std::runtime_error("all_gather_v is not supported in XCCL backend"); +} + +c10::intrusive_ptr TorchCommXCCL::all_gather_single(at::Tensor &output, const at::Tensor &input, bool async_op, const AllGatherSingleOptions &options) { @@ -613,8 +620,8 @@ TorchCommXCCL::all_gather_single(at::Tensor &output, const at::Tensor &input, return work; } -std::shared_ptr TorchCommXCCL::reduce_scatter( - at::Tensor &output, const std::vector &input_list, ReduceOp op, 
+c10::intrusive_ptr TorchCommXCCL::reduce_scatter( + at::Tensor &output, const std::vector &input_list, const ReduceOp &op, bool async_op, const ReduceScatterOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); @@ -665,8 +672,14 @@ std::shared_ptr TorchCommXCCL::reduce_scatter( return work; } -std::shared_ptr TorchCommXCCL::reduce_scatter_single( - at::Tensor &output, const at::Tensor &input, ReduceOp op, bool async_op, +c10::intrusive_ptr TorchCommXCCL::reduce_scatter_v( + at::Tensor &output, const std::vector &input_list, + const ReduceOp &op, bool async_op, const ReduceScatterOptions &options) { + throw std::runtime_error("reduce_scatter_v is not supported in XCCL backend"); +} + +c10::intrusive_ptr TorchCommXCCL::reduce_scatter_single( + at::Tensor &output, const at::Tensor &input, const ReduceOp &op, bool async_op, const ReduceScatterSingleOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); @@ -703,7 +716,7 @@ std::shared_ptr TorchCommXCCL::reduce_scatter_single( return work; } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::all_to_all_single(at::Tensor &output, const at::Tensor &input, bool async_op, const AllToAllSingleOptions &options) { @@ -748,7 +761,7 @@ TorchCommXCCL::all_to_all_single(at::Tensor &output, const at::Tensor &input, return work; } -std::shared_ptr TorchCommXCCL::all_to_all_v_single( +c10::intrusive_ptr TorchCommXCCL::all_to_all_v_single( at::Tensor &output, const at::Tensor &input, const std::vector &output_split_sizes, const std::vector &input_split_sizes, bool async_op, @@ -823,7 +836,7 @@ std::shared_ptr TorchCommXCCL::all_to_all_v_single( return work; } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::all_to_all(const std::vector &output_tensor_list, const std::vector &input_tensor_list, bool async_op, const AllToAllOptions &options) { @@ -874,7 +887,7 @@ TorchCommXCCL::all_to_all(const std::vector &output_tensor_list, return work; } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::barrier(bool async_op, const BarrierOptions &options) { checkInitialized(); checkAndAbortIfTimedOutOrError(); @@ -902,7 +915,7 @@ TorchCommXCCL::barrier(bool async_op, const BarrierOptions &options) { return work; } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::scatter(at::Tensor &output_tensor, const std::vector &input_tensor_list, int root, bool async_op, const ScatterOptions &options) { @@ -974,7 +987,7 @@ TorchCommXCCL::scatter(at::Tensor &output_tensor, return work; } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::gather(const std::vector &output_tensor_list, const at::Tensor &input_tensor, int root, bool async_op, const GatherOptions &options) { diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp index 487d4377..a86a0d59 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.hpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp @@ -65,63 +65,71 @@ class TorchCommXCCL : public TorchCommBackend, std::string_view getCommName() const override; // Point-to-Point Operations - std::shared_ptr send(const at::Tensor &tensor, int dst, - bool async_op, - const SendOptions &options = {}) override; - std::shared_ptr recv(at::Tensor &tensor, int src, bool async_op, - const RecvOptions &options = {}) override; + c10::intrusive_ptr send(const at::Tensor &tensor, int dst, + bool async_op, + const SendOptions &options = {}) override; + c10::intrusive_ptr recv(at::Tensor &tensor, int src, bool async_op, + const RecvOptions &options = {}) override; // Batch P2P Operations - std::shared_ptr + 
c10::intrusive_ptr batch_op_issue(const std::vector &ops, bool async_op, const BatchP2POptions &options = {}) override; // Collective Operations - std::shared_ptr + c10::intrusive_ptr broadcast(at::Tensor &tensor, int root, bool async_op, const BroadcastOptions &options = {}) override; - std::shared_ptr - all_reduce(at::Tensor &tensor, ReduceOp op, bool async_op, + c10::intrusive_ptr + all_reduce(at::Tensor &tensor, const ReduceOp &op, bool async_op, const AllReduceOptions &options = {}) override; - std::shared_ptr reduce(const at::Tensor &tensor, int root, - ReduceOp op, bool async_op, - const ReduceOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr + reduce(const at::Tensor &tensor, int root, const ReduceOp &op, bool async_op, + const ReduceOptions &options = {}) override; + c10::intrusive_ptr all_gather(const std::vector &tensor_list, const at::Tensor &tensor, bool async_op, const AllGatherOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr + all_gather_v(const std::vector &tensor_list, + const at::Tensor &tensor, bool async_op, + const AllGatherOptions &options = {}) override; + c10::intrusive_ptr all_gather_single(at::Tensor &output, const at::Tensor &input, bool async_op, const AllGatherSingleOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr reduce_scatter(at::Tensor &output, const std::vector &input_list, - ReduceOp op, bool async_op, + const ReduceOp &op, bool async_op, const ReduceScatterOptions &options = {}) override; - std::shared_ptr reduce_scatter_single( - at::Tensor &output, const at::Tensor &input, ReduceOp op, bool async_op, + c10::intrusive_ptr + reduce_scatter_v(at::Tensor &output, const std::vector &input_list, + const ReduceOp &op, bool async_op, + const ReduceScatterOptions &options = {}) override; + c10::intrusive_ptr reduce_scatter_single( + at::Tensor &output, const at::Tensor &input, const ReduceOp &op, bool async_op, const ReduceScatterSingleOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr all_to_all_single(at::Tensor &output, const at::Tensor &input, bool async_op, const AllToAllSingleOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr all_to_all_v_single(at::Tensor &output, const at::Tensor &input, const std::vector &output_split_sizes, const std::vector &input_split_sizes, bool async_op, const AllToAllvSingleOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr all_to_all(const std::vector &output_tensor_list, const std::vector &input_tensor_list, bool async_op, const AllToAllOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr barrier(bool async_op, const BarrierOptions &options = {}) override; // Scatter and Gather Operations - std::shared_ptr + c10::intrusive_ptr scatter(at::Tensor &output_tensor, const std::vector &input_tensor_list, int root, bool async_op, const ScatterOptions &options = {}) override; - std::shared_ptr + c10::intrusive_ptr gather(const std::vector &output_tensor_list, const at::Tensor &input_tensor, int root, bool async_op, const GatherOptions &options = {}) override; @@ -141,7 +149,9 @@ class TorchCommXCCL : public TorchCommBackend, XcclApi *getXcclApi() const { return xccl_api_.get(); } // Method to override the XCCL API implementation for testing - void setXcclApi(std::shared_ptr api) { xccl_api_ = std::move(api); } + void setXcclApi(std::shared_ptr api) { + xccl_api_ = std::move(api); + } // Method to override the XPU API implementation for testing void setXpuApi(std::shared_ptr api) { 
xpu_api_ = std::move(api); } @@ -166,7 +176,7 @@ class TorchCommXCCL : public TorchCommBackend, CommState::NORMAL}; // State of the communicator onecclDataType_t getXcclDataType(const at::Tensor &tensor); - std::shared_ptr + c10::intrusive_ptr createWork(xpuStream_t stream, std::chrono::milliseconds timeout, const std::vector &inputTensors); @@ -205,7 +215,7 @@ class TorchCommXCCL : public TorchCommBackend, void checkInitialized() const; void checkAndAbortIfTimedOutOrError(); void checkWorkQueue(bool isMainThread); - void enqueueWork(std::shared_ptr work, xpuStream_t stream); + void enqueueWork(c10::intrusive_ptr work, xpuStream_t stream); xpuStream_t getOperationStream(bool async_op); void ensureTensorContiguous(const at::Tensor &tensor); diff --git a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp index 6b43f555..ea677d4f 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp +++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp @@ -19,7 +19,8 @@ constexpr uint16_t kTCPStorePort = 29500; class TorchCommXCCLBootstrap { public: TorchCommXCCLBootstrap(c10::intrusive_ptr store, - c10::Device device, std::shared_ptr xccl_api, + c10::Device device, + std::shared_ptr xccl_api, std::shared_ptr xpu_api, std::chrono::milliseconds timeout); ~TorchCommXCCLBootstrap(); diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp index d338ae2c..da754fb2 100644 --- a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp @@ -1,5 +1,5 @@ -#include "comms/torchcomms/xccl/TorchCommXCCL.hpp" #include "comms/torchcomms/TorchCommLogging.hpp" +#include "comms/torchcomms/xccl/TorchCommXCCL.hpp" #include #include #include @@ -217,7 +217,7 @@ void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { checkWorkQueue(true); if (comm_state_ == CommState::TIMEOUT) { -// abortXcclComm(); // cannot abort oneCCL communicator + // abortXcclComm(); // cannot abort oneCCL communicator if (options_.abort_process_on_timeout_or_error) { TC_LOG(ERROR) << "Aborting process due to timeout"; abort(); @@ -228,7 +228,7 @@ void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { onecclResult_t asyncErr; xccl_api_->commGetAsyncError(xccl_comm_, &asyncErr); XCCLException xcclException(*xccl_api_, "XCCL Async Error", asyncErr); -// abortXcclComm(); // cannot abort oneCCL communicator + // abortXcclComm(); // cannot abort oneCCL communicator if (options_.abort_process_on_timeout_or_error) { TC_LOG(ERROR) << "Aborting process due to error: " << xcclException.what(); @@ -239,16 +239,16 @@ void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { } } -std::shared_ptr +c10::intrusive_ptr TorchCommXCCL::createWork(xpuStream_t stream, std::chrono::milliseconds timeout, const std::vector &inputTensors) { // Only create the work object without enqueuing it - auto work = std::make_shared(shared_from_this(), stream, - timeout, inputTensors, tracing_); + auto work = c10::make_intrusive(shared_from_this(), stream, + timeout, inputTensors, tracing_); return work; } -void TorchCommXCCL::enqueueWork(std::shared_ptr work, +void TorchCommXCCL::enqueueWork(c10::intrusive_ptr work, xpuStream_t stream) { // Add work to stream's queue after events have been recorded workq_.enqueueWork(std::move(work), stream); diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.hpp b/comms/torchcomms/xccl/TorchWorkXCCL.hpp index ed60c604..d81e557e 100644 --- a/comms/torchcomms/xccl/TorchWorkXCCL.hpp +++ 
b/comms/torchcomms/xccl/TorchWorkXCCL.hpp @@ -8,10 +8,10 @@ #include #include -#include #include "comms/torchcomms/TorchCommTracing.hpp" #include "comms/torchcomms/TorchWork.hpp" #include "comms/torchcomms/device/XpuApi.hpp" +#include namespace torch { namespace comms { @@ -82,12 +82,12 @@ class TorchWorkXCCLQueue { TorchWorkXCCL::WorkStatus garbageCollect(bool isMainThread); // Finalize function can only be called from the main thread TorchWorkXCCL::WorkStatus finalize(); - void enqueueWork(std::shared_ptr work, xpuStream_t stream); + void enqueueWork(c10::intrusive_ptr work, xpuStream_t stream); private: - std::unordered_map>> + std::unordered_map>> stream_work_queues_; - std::vector> completed_work_queue_; + std::vector> completed_work_queue_; std::recursive_mutex work_queues_mutex_; }; diff --git a/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp index 20d94d22..e4edecbc 100644 --- a/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp +++ b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp @@ -85,7 +85,7 @@ TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::finalize() { return status; } -void TorchWorkXCCLQueue::enqueueWork(std::shared_ptr work, +void TorchWorkXCCLQueue::enqueueWork(c10::intrusive_ptr work, xpuStream_t stream) { // Add work to stream's queue after events have been recorded std::lock_guard lock(work_queues_mutex_); From 2ce3df27b98d52f4d51e37b5359effa6b84da299 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Fri, 7 Nov 2025 10:27:40 +0800 Subject: [PATCH 07/10] support xccl test --- .../tests/integration/py/AllGatherSingleTest.py | 2 ++ comms/torchcomms/tests/integration/py/AllGatherTest.py | 2 ++ comms/torchcomms/tests/integration/py/AllGatherVTest.py | 2 ++ comms/torchcomms/tests/integration/py/AllReduceTest.py | 3 +++ .../torchcomms/tests/integration/py/AllToAllSingleTest.py | 2 ++ comms/torchcomms/tests/integration/py/AllToAllTest.py | 2 ++ comms/torchcomms/tests/integration/py/BroadcastTest.py | 2 ++ comms/torchcomms/tests/integration/py/GatherTest.py | 2 ++ comms/torchcomms/tests/integration/py/ObjColTest.py | 3 +++ .../tests/integration/py/ReduceScatterSingleTest.py | 3 +++ .../torchcomms/tests/integration/py/ReduceScatterTest.py | 3 +++ comms/torchcomms/tests/integration/py/ReduceTest.py | 3 +++ comms/torchcomms/tests/integration/py/ScatterTest.py | 2 ++ .../tests/integration/py/TorchCommTestHelpers.py | 8 +++++--- 14 files changed, 36 insertions(+), 3 deletions(-) diff --git a/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py b/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py index 35e26cc1..fe646c8f 100644 --- a/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py +++ b/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py @@ -18,6 +18,8 @@ class AllGatherSingleTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllGatherTest.py b/comms/torchcomms/tests/integration/py/AllGatherTest.py index 8c76aeec..01f4d5f4 100644 --- a/comms/torchcomms/tests/integration/py/AllGatherTest.py +++ b/comms/torchcomms/tests/integration/py/AllGatherTest.py @@ -18,6 +18,8 @@ class AllGatherTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = 
[torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllGatherVTest.py b/comms/torchcomms/tests/integration/py/AllGatherVTest.py index f971144d..eda57a38 100644 --- a/comms/torchcomms/tests/integration/py/AllGatherVTest.py +++ b/comms/torchcomms/tests/integration/py/AllGatherVTest.py @@ -19,6 +19,8 @@ class AllGatherVTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] num_replays = 4 def get_wrapper(self): diff --git a/comms/torchcomms/tests/integration/py/AllReduceTest.py b/comms/torchcomms/tests/integration/py/AllReduceTest.py index 81188429..6bb2ce03 100644 --- a/comms/torchcomms/tests/integration/py/AllReduceTest.py +++ b/comms/torchcomms/tests/integration/py/AllReduceTest.py @@ -21,6 +21,9 @@ class AllReduceTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] + dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py b/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py index 7c4934d1..0be58374 100644 --- a/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py +++ b/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py @@ -18,6 +18,8 @@ class AllToAllSingleTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllToAllTest.py b/comms/torchcomms/tests/integration/py/AllToAllTest.py index 4a8b73eb..11daa5de 100644 --- a/comms/torchcomms/tests/integration/py/AllToAllTest.py +++ b/comms/torchcomms/tests/integration/py/AllToAllTest.py @@ -18,6 +18,8 @@ class AllToAllTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/BroadcastTest.py b/comms/torchcomms/tests/integration/py/BroadcastTest.py index 8a21de03..290767d4 100644 --- a/comms/torchcomms/tests/integration/py/BroadcastTest.py +++ b/comms/torchcomms/tests/integration/py/BroadcastTest.py @@ -18,6 +18,8 @@ class BroadcastTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/GatherTest.py b/comms/torchcomms/tests/integration/py/GatherTest.py index 496397f7..158596bb 100644 --- a/comms/torchcomms/tests/integration/py/GatherTest.py +++ b/comms/torchcomms/tests/integration/py/GatherTest.py @@ -18,6 +18,8 @@ class GatherTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ObjColTest.py 
b/comms/torchcomms/tests/integration/py/ObjColTest.py index 28e6e897..30028874 100644 --- a/comms/torchcomms/tests/integration/py/ObjColTest.py +++ b/comms/torchcomms/tests/integration/py/ObjColTest.py @@ -2,6 +2,7 @@ # pyre-unsafe # Copyright (c) Meta Platforms, Inc. and affiliates. +import os import unittest from contextlib import contextmanager @@ -29,6 +30,8 @@ class ObjColTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py b/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py index 508f50e6..027360f2 100644 --- a/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py +++ b/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py @@ -21,6 +21,9 @@ class ReduceScatterSingleTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] + dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ReduceScatterTest.py b/comms/torchcomms/tests/integration/py/ReduceScatterTest.py index f448a80a..19831a76 100644 --- a/comms/torchcomms/tests/integration/py/ReduceScatterTest.py +++ b/comms/torchcomms/tests/integration/py/ReduceScatterTest.py @@ -21,6 +21,9 @@ class ReduceScatterTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] + dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ReduceTest.py b/comms/torchcomms/tests/integration/py/ReduceTest.py index c593c3e5..339d4172 100644 --- a/comms/torchcomms/tests/integration/py/ReduceTest.py +++ b/comms/torchcomms/tests/integration/py/ReduceTest.py @@ -21,6 +21,9 @@ class ReduceTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] + dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ScatterTest.py b/comms/torchcomms/tests/integration/py/ScatterTest.py index e606675d..70a629ea 100644 --- a/comms/torchcomms/tests/integration/py/ScatterTest.py +++ b/comms/torchcomms/tests/integration/py/ScatterTest.py @@ -18,6 +18,8 @@ class ScatterTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] + if os.environ.get("TEST_BACKEND") == "xccl": + counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py b/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py index b6a955c2..44da2f05 100644 --- a/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py +++ b/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py @@ -218,11 +218,13 @@ def get_device(self, backend, rank): if device := os.environ.get("TEST_DEVICE"): return torch.device(device) - # Check for CUDA 
availability and abort if not available - if not torch.cuda.is_available(): + # Check for accelerator availability and abort if not available + if not torch.accelerator.is_available(): return torch.device("cpu") - device_id = rank % torch.cuda.device_count() + device_id = rank % torch.accelerator.device_count() + if os.getenv("TEST_BACKEND") == "xccl": + return torch.device(f"xpu:{device_id}") return torch.device(f"cuda:{device_id}") def __init__(self, store=None): From a16198dabee542c8c7fde9c1d7f97dc98672ef4e Mon Sep 17 00:00:00 2001 From: lzhang2 Date: Mon, 10 Nov 2025 10:52:47 +0800 Subject: [PATCH 08/10] support allreduce --- comms/torchcomms/xccl/TorchCommXCCL.cpp | 659 +----------------------- 1 file changed, 16 insertions(+), 643 deletions(-) diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp index 75d3a9fe..b0c7bf82 100644 --- a/comms/torchcomms/xccl/TorchCommXCCL.cpp +++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp @@ -286,175 +286,27 @@ getOperationTimeout(std::chrono::milliseconds timeout, c10::intrusive_ptr TorchCommXCCL::send(const at::Tensor &tensor, int dst, bool async_op, const SendOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(tensor); - - tracing_->recordEventWithInputOutput("send", dst, {tensor}, {tensor}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - - work->recordStart(); - - onecclResult_t result = - xccl_api_->send(tensor.data_ptr(), tensor.numel(), - getXcclDataType(tensor), dst, xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL Send failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL send is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::recv(at::Tensor &tensor, int src, bool async_op, const RecvOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(tensor); - - tracing_->recordEventWithInputOutput("recv", src, {tensor}, {tensor}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {}); - - work->recordStart(); - - onecclResult_t result = - xccl_api_->recv(tensor.data_ptr(), tensor.numel(), - getXcclDataType(tensor), src, xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL Recv failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL recv is not supported now and will be added later"); } // Batch P2P Operations c10::intrusive_ptr TorchCommXCCL::batch_op_issue(const std::vector &ops, bool async_op, const BatchP2POptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - - if (ops.empty()) { - throw std::runtime_error("Cannot issue empty batch operation"); - } - - // Collect input and output tensors for work tracking - std::vector input_tensors; - std::vector output_tensors; - - for (const auto &op : ops) { - if (op.type == BatchSendRecv::P2POp::OpType::SEND) { - at::Tensor tensor = op.tensor; - ensureTensorContiguous(tensor); - input_tensors.push_back(tensor); - } else if (op.type == BatchSendRecv::P2POp::OpType::RECV) { - at::Tensor tensor = op.tensor; - ensureTensorContiguous(tensor); - output_tensors.push_back(tensor); - } else { - 
throw std::runtime_error("Unknown op type"); - } - } - - tracing_->recordEventWithInputOutput("batch_op_issue", rank_, input_tensors, - output_tensors); - - xpuStream_t stream = getOperationStream(async_op); - auto work = - createWork(stream, getOperationTimeout(options.timeout, options_.timeout), - input_tensors); - - work->recordStart(); - - // Start XCCL group for batched operations - onecclResult_t result = xccl_api_->groupStart(); - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL GroupStart failed", result); - } - - // Issue each operation individually - for (const auto &op : ops) { - if (op.type == BatchSendRecv::P2POp::OpType::SEND) { - result = xccl_api_->send(op.tensor.data_ptr(), op.tensor.numel(), - getXcclDataType(op.tensor), op.peer, xccl_comm_, - stream); - - if (result != onecclSuccess) { - xccl_api_->groupEnd(); // Clean up group on error - throw XCCLException(*xccl_api_, "XCCL Send failed in batch operation", - result); - } - } else if (op.type == BatchSendRecv::P2POp::OpType::RECV) { - result = xccl_api_->recv(op.tensor.data_ptr(), op.tensor.numel(), - getXcclDataType(op.tensor), op.peer, xccl_comm_, - stream); - - if (result != onecclSuccess) { - xccl_api_->groupEnd(); // Clean up group on error - throw XCCLException(*xccl_api_, "XCCL Recv failed in batch operation", - result); - } - } - } - - result = xccl_api_->groupEnd(); - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL GroupEnd failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL batch_op_issue is not supported now and will be added later"); } // Collective Operations c10::intrusive_ptr TorchCommXCCL::broadcast(at::Tensor &tensor, int root, bool async_op, const BroadcastOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(tensor); - - tracing_->recordEventWithInputOutput("broadcast", root, {tensor}, {tensor}); - - xpuStream_t stream = getOperationStream(async_op); - - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - - work->recordStart(); - - onecclResult_t result = - xccl_api_->bcast(tensor.data_ptr(), tensor.numel(), - getXcclDataType(tensor), root, xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL Broadcast failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL broadcast is not supported now and will be added later"); } c10::intrusive_ptr @@ -494,85 +346,14 @@ c10::intrusive_ptr TorchCommXCCL::reduce(const at::Tensor &tensor, int root, const ReduceOp &op, bool async_op, const ReduceOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(tensor); - - tracing_->recordEventWithInputOutput("reduce", root, {tensor}, {tensor}); - - xpuStream_t stream = getOperationStream(async_op); - std::vector output_tensors; - if (rank_ == root) { - output_tensors.push_back(tensor); - } - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - - work->recordStart(); - - const auto dataType = getXcclDataType(tensor); - onecclResult_t result = xccl_api_->reduce( - tensor.data_ptr(), - tensor.data_ptr(), // Use same buffer for all ranks - tensor.numel(), dataType, getXcclReduceOp(op, xccl_comm_, dataType), root, - xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL 
Reduce failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL reduce is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::all_gather(const std::vector &tensor_list, const at::Tensor &tensor, bool async_op, const AllGatherOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - if (tensor_list.size() != static_cast(comm_size_)) { - throw std::runtime_error( - "tensor_list size must equal comm_size for all_gather"); - } - - ensureTensorContiguous(tensor); - - for (const auto &t : tensor_list) { - ensureTensorContiguous(t); - if (t.numel() != tensor.numel()) { - throw std::runtime_error( - "All tensors in tensor_list must have same size as input tensor"); - } - } - - tracing_->recordEventWithInputOutput("all_gather", rank_, tensor_list, - {tensor}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); - - work->recordStart(); - - xccl_api_->groupStart(); - - for (int i = 0; i < comm_size_; ++i) { - xccl_api_->broadcast(tensor.data_ptr(), tensor_list[i].data_ptr(), - tensor.numel(), getXcclDataType(tensor_list[i]), i, - xccl_comm_, stream); - } - - xccl_api_->groupEnd(); - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL all_gather is not supported now and will be added later"); } c10::intrusive_ptr @@ -586,90 +367,13 @@ c10::intrusive_ptr TorchCommXCCL::all_gather_single(at::Tensor &output, const at::Tensor &input, bool async_op, const AllGatherSingleOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(output); - ensureTensorContiguous(input); - - if (output.numel() != input.numel() * comm_size_) { - throw std::runtime_error("Output tensor size must be input_size * " - "comm_size for all_gather_single"); - } - - tracing_->recordEventWithInputOutput("all_gather_single", rank_, {input}, - {output}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {input}); - - work->recordStart(); - - onecclResult_t result = - xccl_api_->allGather(input.data_ptr(), output.data_ptr(), input.numel(), - getXcclDataType(input), xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL AllGather failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL all_gather_single is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::reduce_scatter( at::Tensor &output, const std::vector &input_list, const ReduceOp &op, bool async_op, const ReduceScatterOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(output); - - if (input_list.size() != static_cast(comm_size_)) { - throw std::runtime_error( - "input_list size must equal comm_size for reduce_scatter"); - } - - // Check that all input tensors are contiguous and have correct size - for (const auto &t : input_list) { - ensureTensorContiguous(t); - if (t.numel() != output.numel()) { - throw std::runtime_error( - "All input tensors must have same size as output tensor"); - } - } - - tracing_->recordEventWithInputOutput("reduce_scatter", rank_, input_list, - {output}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = - createWork(stream, 
getOperationTimeout(options.timeout, options_.timeout), - input_list); - - work->recordStart(); - - // Use multiple reduce operations for reduce_scatter - xccl_api_->groupStart(); - - for (int i = 0; i < comm_size_; ++i) { - const auto dataType = getXcclDataType(input_list[i]); - xccl_api_->reduce(input_list[i].data_ptr(), - i == rank_ ? output.data_ptr() : input_list[i].data_ptr(), - i == rank_ ? output.numel() : input_list[i].numel(), - dataType, getXcclReduceOp(op, xccl_comm_, dataType), i, - xccl_comm_, stream); - } - - xccl_api_->groupEnd(); - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL reduce_scatter is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::reduce_scatter_v( @@ -681,84 +385,14 @@ c10::intrusive_ptr TorchCommXCCL::reduce_scatter_v( c10::intrusive_ptr TorchCommXCCL::reduce_scatter_single( at::Tensor &output, const at::Tensor &input, const ReduceOp &op, bool async_op, const ReduceScatterSingleOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(output); - ensureTensorContiguous(input); - - if (input.numel() != output.numel() * comm_size_) { - throw std::runtime_error("Input tensor size must be output_size * " - "comm_size for reduce_scatter_single"); - } - - tracing_->recordEventWithInputOutput("reduce_scatter_single", rank_, {input}, - {output}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {input}); - - work->recordStart(); - - const auto dataType = getXcclDataType(input); - onecclResult_t result = xccl_api_->reduceScatter( - input.data_ptr(), output.data_ptr(), output.numel(), dataType, - getXcclReduceOp(op, xccl_comm_, dataType), xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL ReduceScatter failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL reduce_scatter_single is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::all_to_all_single(at::Tensor &output, const at::Tensor &input, bool async_op, const AllToAllSingleOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(output); - ensureTensorContiguous(input); - - if (input.numel() != output.numel()) { - throw std::runtime_error( - "Input and output tensors must have same size for all_to_all_single"); - } - - if (input.numel() % comm_size_ != 0) { - throw std::runtime_error( - "Tensor size must be divisible by comm_size for all_to_all_single"); - } - - tracing_->recordEventWithInputOutput("all_to_all_single", rank_, {input}, - {output}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {input}); - - work->recordStart(); - - size_t chunk_size = input.numel() / comm_size_; - const auto data_type = getXcclDataType(input); - - onecclResult_t result = - xccl_api_->allToAll(input.data_ptr(), output.data_ptr(), chunk_size, - data_type, xccl_comm_, stream); - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL AllToAll failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL all_to_all_single is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::all_to_all_v_single( @@ -766,300 +400,39 
@@ c10::intrusive_ptr TorchCommXCCL::all_to_all_v_single( const std::vector &output_split_sizes, const std::vector &input_split_sizes, bool async_op, const AllToAllvSingleOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(output); - ensureTensorContiguous(input); - - // Validate split sizes vectors - if (input_split_sizes.size() != static_cast(comm_size_)) { - throw std::runtime_error("input_split_sizes length must equal comm_size " - "for all_to_all_v_single"); - } - - if (output_split_sizes.size() != static_cast(comm_size_)) { - throw std::runtime_error("output_split_sizes length must equal comm_size " - "for all_to_all_v_single"); - } - - tracing_->recordEventWithInputOutput("all_to_all_v_single", rank_, {input}, - {output}); - - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {input}); - - work->recordStart(); - - // Convert split sizes to arrays and calculate displacements - std::vector sendcounts(comm_size_); - std::vector recvcounts(comm_size_); - std::vector senddispls(comm_size_); - std::vector recvdispls(comm_size_); - - // Calculate the number of elements per slice along the first dimension - // For a tensor with shape [N, D1, D2, ..., Dk], each slice of size S along - // dim 0 contains S * D1 * D2 * ... * Dk elements - size_t elements_per_slice = input.numel() ? input.numel() / input.size(0) : 0; - const auto data_type = getXcclDataType(input); - const size_t type_size = wordSize(data_type); - - size_t sendoffset = 0; - size_t recvoffset = 0; - for (int i = 0; i < comm_size_; ++i) { - sendcounts[i] = input_split_sizes[i] * elements_per_slice; - recvcounts[i] = output_split_sizes[i] * elements_per_slice; - senddispls[i] = sendoffset; - recvdispls[i] = recvoffset; - sendoffset += sendcounts[i]; - recvoffset += recvcounts[i]; - } - - char *sptr = static_cast(input.data_ptr()); - char *rptr = static_cast(output.data_ptr()); - - xccl_api_->groupStart(); - - for (int i = 0; i < comm_size_; ++i) { - xccl_api_->send(sptr + senddispls[i] * type_size, sendcounts[i], data_type, - i, xccl_comm_, stream); - xccl_api_->recv(rptr + recvdispls[i] * type_size, recvcounts[i], data_type, - i, xccl_comm_, stream); - } - - xccl_api_->groupEnd(); - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL all_to_all_v_single is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::all_to_all(const std::vector &output_tensor_list, const std::vector &input_tensor_list, bool async_op, const AllToAllOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - if (output_tensor_list.size() != static_cast(comm_size_) || - input_tensor_list.size() != static_cast(comm_size_)) { - throw std::runtime_error( - "Tensor list sizes must equal comm_size for all_to_all"); - } - - // Validate all tensors - for (int i = 0; i < comm_size_; ++i) { - ensureTensorContiguous(input_tensor_list[i]); - ensureTensorContiguous(output_tensor_list[i]); - } - - tracing_->recordEventWithInputOutput("all_to_all", rank_, input_tensor_list, - output_tensor_list); - - xpuStream_t stream = getOperationStream(async_op); - auto work = - createWork(stream, getOperationTimeout(options.timeout, options_.timeout), - input_tensor_list); - - work->recordStart(); - - xccl_api_->groupStart(); - - for (int i = 0; i < comm_size_; ++i) { - // Send to rank i - xccl_api_->send( - input_tensor_list[i].data_ptr(), 
input_tensor_list[i].numel(), - getXcclDataType(input_tensor_list[i]), i, xccl_comm_, stream); - - // Receive from rank i - xccl_api_->recv( - output_tensor_list[i].data_ptr(), output_tensor_list[i].numel(), - getXcclDataType(output_tensor_list[i]), i, xccl_comm_, stream); - } - - xccl_api_->groupEnd(); - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL all_to_all is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::barrier(bool async_op, const BarrierOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - - tracing_->recordEvent("barrier"); - xpuStream_t stream = getOperationStream(async_op); - auto work = createWork( - stream, getOperationTimeout(options.timeout, options_.timeout), {}); - - work->recordStart(); - - // Use pre-allocated XPU buffer for barrier - onecclResult_t result = - xccl_api_->allReduce(barrier_buffer_, barrier_buffer_, 1, onecclFloat32, - onecclSum, xccl_comm_, stream); - - if (result != onecclSuccess) { - throw XCCLException(*xccl_api_, "XCCL Barrier failed", result); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL barrier is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::scatter(at::Tensor &output_tensor, const std::vector &input_tensor_list, int root, bool async_op, const ScatterOptions &options) { - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(output_tensor); - - // Only the root rank needs valid input tensors - if (rank_ == root) { - if (input_tensor_list.size() != static_cast(comm_size_)) { - throw std::runtime_error( - "input_tensor_list size must equal comm_size for scatter"); - } - - for (const auto &t : input_tensor_list) { - ensureTensorContiguous(t); - if (t.numel() != output_tensor.numel()) { - throw std::runtime_error( - "All input tensors must have same size as output tensor"); - } - } - } - - tracing_->recordEventWithInputOutput("scatter", root, input_tensor_list, - {output_tensor}); - - xpuStream_t stream = getOperationStream(async_op); - std::vector input_tensors; - if (rank_ == root) { - input_tensors = input_tensor_list; - } - auto work = - createWork(stream, getOperationTimeout(options.timeout, options_.timeout), - input_tensors); - - work->recordStart(); - - // Implement scatter using point-to-point operations - if (rank_ == root) { - // Root sends to all ranks (except itself) - xccl_api_->groupStart(); - for (int i = 0; i < comm_size_; ++i) { - if (i != root) { - xccl_api_->send( - input_tensor_list[i].data_ptr(), input_tensor_list[i].numel(), - getXcclDataType(input_tensor_list[i]), i, xccl_comm_, stream); - } - } - xccl_api_->groupEnd(); - - // Root copies its own data using xpuMemcpyAsync - XPU_CHECK(xpu_api_, - xpu_api_->memcpyAsync(output_tensor.data_ptr(), - input_tensor_list[root].data_ptr(), - input_tensor_list[root].numel() * - input_tensor_list[root].element_size(), - stream), - "memcpyAsync failed"); - } else { - // Non-root ranks receive from root - xccl_api_->recv(output_tensor.data_ptr(), output_tensor.numel(), - getXcclDataType(output_tensor), root, xccl_comm_, stream); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL scatter is not supported now and will be added later"); } c10::intrusive_ptr TorchCommXCCL::gather(const std::vector &output_tensor_list, const at::Tensor &input_tensor, int root, bool async_op, const GatherOptions &options) 
{ - checkInitialized(); - checkAndAbortIfTimedOutOrError(); - ensureTensorContiguous(input_tensor); - - // Only the root rank needs valid output tensors - if (rank_ == root) { - if (output_tensor_list.size() != static_cast(comm_size_)) { - throw std::runtime_error( - "output_tensor_list size must equal comm_size for gather"); - } - - for (const auto &t : output_tensor_list) { - ensureTensorContiguous(t); - if (t.numel() != input_tensor.numel()) { - throw std::runtime_error( - "All output tensors must have same size as input tensor"); - } - } - } - - tracing_->recordEventWithInputOutput("gather", root, {input_tensor}, - output_tensor_list); - - xpuStream_t stream = getOperationStream(async_op); - std::vector output_tensors; - if (rank_ == root) { - output_tensors = output_tensor_list; - } - auto work = - createWork(stream, getOperationTimeout(options.timeout, options_.timeout), - {input_tensor}); - - work->recordStart(); - - if (rank_ == root) { - // Root receives from all ranks (except itself) - xccl_api_->groupStart(); - for (int i = 0; i < comm_size_; ++i) { - if (i != root) { - xccl_api_->recv( - output_tensor_list[i].data_ptr(), output_tensor_list[i].numel(), - getXcclDataType(output_tensor_list[i]), i, xccl_comm_, stream); - } - } - xccl_api_->groupEnd(); - - // Root copies its own data using xpuMemcpyAsync - XPU_CHECK(xpu_api_, - xpu_api_->memcpyAsync( - output_tensor_list[root].data_ptr(), input_tensor.data_ptr(), - input_tensor.numel() * input_tensor.element_size(), stream), - "memcpyAsync failed"); - } else { - // Non-root ranks send to root - xccl_api_->send(input_tensor.data_ptr(), input_tensor.numel(), - getXcclDataType(input_tensor), root, xccl_comm_, stream); - } - - work->recordEnd(); - - enqueueWork(work, stream); - - return work; + throw std::runtime_error("XCCL gather is not supported now and will be added later"); } std::shared_ptr TorchCommXCCL::split(const std::vector &ranks, const std::string &name, const CommOptions &options) { - throw std::runtime_error("Split is not supported now in XCCL"); + throw std::runtime_error("XCCL split is not supported now and will be added later"); } XCCLException::XCCLException(XcclApi &xccl_api, const std::string &message, From 5e8b85ad22d0a1b69e20c322dd99057798d07ff5 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 10 Nov 2025 14:11:08 +0800 Subject: [PATCH 09/10] Revert "support xccl test" This reverts commit 2ce3df27b98d52f4d51e37b5359effa6b84da299. 
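The revert drops the XCCL-specific parameter overrides that the reverted commit had copied into each integration test: under TEST_BACKEND=xccl the zero-element count was removed from counts, and the reduction tests (AllReduce, Reduce, ReduceScatter, ReduceScatterSingle) also removed torch.int8 from dtypes. If that gating is reintroduced later, a single shared helper would avoid repeating the check in every test file. A minimal sketch assuming the same skip rules; filter_xccl_params is a hypothetical name, not part of this series:

    import os

    import torch

    def filter_xccl_params(counts, dtypes, is_reduction=False):
        # Hypothetical helper: mirrors the per-file overrides removed by
        # this revert. XCCL skipped zero-element collectives, and the
        # reduction tests additionally skipped int8.
        if os.environ.get("TEST_BACKEND") == "xccl":
            counts = [c for c in counts if c != 0]
            if is_reduction:
                dtypes = [d for d in dtypes if d != torch.int8]
        return counts, dtypes

    # Usage inside a test class body, e.g. AllReduceTest:
    # counts, dtypes = filter_xccl_params(
    #     [0, 4, 1024, 1024 * 1024],
    #     [torch.float, torch.int, torch.int8],
    #     is_reduction=True,
    # )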
--- .../tests/integration/py/AllGatherSingleTest.py | 2 -- comms/torchcomms/tests/integration/py/AllGatherTest.py | 2 -- comms/torchcomms/tests/integration/py/AllGatherVTest.py | 2 -- comms/torchcomms/tests/integration/py/AllReduceTest.py | 3 --- .../torchcomms/tests/integration/py/AllToAllSingleTest.py | 2 -- comms/torchcomms/tests/integration/py/AllToAllTest.py | 2 -- comms/torchcomms/tests/integration/py/BroadcastTest.py | 2 -- comms/torchcomms/tests/integration/py/GatherTest.py | 2 -- comms/torchcomms/tests/integration/py/ObjColTest.py | 3 --- .../tests/integration/py/ReduceScatterSingleTest.py | 3 --- .../torchcomms/tests/integration/py/ReduceScatterTest.py | 3 --- comms/torchcomms/tests/integration/py/ReduceTest.py | 3 --- comms/torchcomms/tests/integration/py/ScatterTest.py | 2 -- .../tests/integration/py/TorchCommTestHelpers.py | 8 +++----- 14 files changed, 3 insertions(+), 36 deletions(-) diff --git a/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py b/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py index fe646c8f..35e26cc1 100644 --- a/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py +++ b/comms/torchcomms/tests/integration/py/AllGatherSingleTest.py @@ -18,8 +18,6 @@ class AllGatherSingleTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllGatherTest.py b/comms/torchcomms/tests/integration/py/AllGatherTest.py index 01f4d5f4..8c76aeec 100644 --- a/comms/torchcomms/tests/integration/py/AllGatherTest.py +++ b/comms/torchcomms/tests/integration/py/AllGatherTest.py @@ -18,8 +18,6 @@ class AllGatherTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllGatherVTest.py b/comms/torchcomms/tests/integration/py/AllGatherVTest.py index eda57a38..f971144d 100644 --- a/comms/torchcomms/tests/integration/py/AllGatherVTest.py +++ b/comms/torchcomms/tests/integration/py/AllGatherVTest.py @@ -19,8 +19,6 @@ class AllGatherVTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] num_replays = 4 def get_wrapper(self): diff --git a/comms/torchcomms/tests/integration/py/AllReduceTest.py b/comms/torchcomms/tests/integration/py/AllReduceTest.py index 6bb2ce03..81188429 100644 --- a/comms/torchcomms/tests/integration/py/AllReduceTest.py +++ b/comms/torchcomms/tests/integration/py/AllReduceTest.py @@ -21,9 +21,6 @@ class AllReduceTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] - dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py b/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py index 0be58374..7c4934d1 100644 --- a/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py +++ b/comms/torchcomms/tests/integration/py/AllToAllSingleTest.py @@ 
-18,8 +18,6 @@ class AllToAllSingleTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/AllToAllTest.py b/comms/torchcomms/tests/integration/py/AllToAllTest.py index 11daa5de..4a8b73eb 100644 --- a/comms/torchcomms/tests/integration/py/AllToAllTest.py +++ b/comms/torchcomms/tests/integration/py/AllToAllTest.py @@ -18,8 +18,6 @@ class AllToAllTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/BroadcastTest.py b/comms/torchcomms/tests/integration/py/BroadcastTest.py index 290767d4..8a21de03 100644 --- a/comms/torchcomms/tests/integration/py/BroadcastTest.py +++ b/comms/torchcomms/tests/integration/py/BroadcastTest.py @@ -18,8 +18,6 @@ class BroadcastTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/GatherTest.py b/comms/torchcomms/tests/integration/py/GatherTest.py index 158596bb..496397f7 100644 --- a/comms/torchcomms/tests/integration/py/GatherTest.py +++ b/comms/torchcomms/tests/integration/py/GatherTest.py @@ -18,8 +18,6 @@ class GatherTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ObjColTest.py b/comms/torchcomms/tests/integration/py/ObjColTest.py index 30028874..28e6e897 100644 --- a/comms/torchcomms/tests/integration/py/ObjColTest.py +++ b/comms/torchcomms/tests/integration/py/ObjColTest.py @@ -2,7 +2,6 @@ # pyre-unsafe # Copyright (c) Meta Platforms, Inc. and affiliates. 
-import os import unittest from contextlib import contextmanager @@ -30,8 +29,6 @@ class ObjColTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py b/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py index 027360f2..508f50e6 100644 --- a/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py +++ b/comms/torchcomms/tests/integration/py/ReduceScatterSingleTest.py @@ -21,9 +21,6 @@ class ReduceScatterSingleTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] - dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ReduceScatterTest.py b/comms/torchcomms/tests/integration/py/ReduceScatterTest.py index 19831a76..f448a80a 100644 --- a/comms/torchcomms/tests/integration/py/ReduceScatterTest.py +++ b/comms/torchcomms/tests/integration/py/ReduceScatterTest.py @@ -21,9 +21,6 @@ class ReduceScatterTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] - dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ReduceTest.py b/comms/torchcomms/tests/integration/py/ReduceTest.py index 339d4172..c593c3e5 100644 --- a/comms/torchcomms/tests/integration/py/ReduceTest.py +++ b/comms/torchcomms/tests/integration/py/ReduceTest.py @@ -21,9 +21,6 @@ class ReduceTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] - dtypes = [torch.float, torch.int] ops = [ReduceOp.SUM, ReduceOp.MAX] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/ScatterTest.py b/comms/torchcomms/tests/integration/py/ScatterTest.py index 70a629ea..e606675d 100644 --- a/comms/torchcomms/tests/integration/py/ScatterTest.py +++ b/comms/torchcomms/tests/integration/py/ScatterTest.py @@ -18,8 +18,6 @@ class ScatterTest(unittest.TestCase): # Class variables for test parameters counts = [0, 4, 1024, 1024 * 1024] - if os.environ.get("TEST_BACKEND") == "xccl": - counts = [4, 1024, 1024 * 1024] dtypes = [torch.float, torch.int, torch.int8] num_replays = 4 diff --git a/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py b/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py index 44da2f05..b6a955c2 100644 --- a/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py +++ b/comms/torchcomms/tests/integration/py/TorchCommTestHelpers.py @@ -218,13 +218,11 @@ def get_device(self, backend, rank): if device := os.environ.get("TEST_DEVICE"): return torch.device(device) - # Check for accelerator availability and abort if not available - if not torch.accelerator.is_available(): + # Check for CUDA availability and abort if not available + if not torch.cuda.is_available(): return torch.device("cpu") - device_id = rank % torch.accelerator.device_count() - if 
os.getenv("TEST_BACKEND") == "xccl": - return torch.device(f"xpu:{device_id}") + device_id = rank % torch.cuda.device_count() return torch.device(f"cuda:{device_id}") def __init__(self, store=None): From 6f2695bdde7c71dcf6201bd1c4167bdff9837669 Mon Sep 17 00:00:00 2001 From: "Han, Chao" Date: Mon, 10 Nov 2025 14:29:58 +0800 Subject: [PATCH 10/10] Add env check --- comms/torchcomms/xccl/CMakeLists.txt | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/comms/torchcomms/xccl/CMakeLists.txt b/comms/torchcomms/xccl/CMakeLists.txt index 88eef1d4..b8726888 100644 --- a/comms/torchcomms/xccl/CMakeLists.txt +++ b/comms/torchcomms/xccl/CMakeLists.txt @@ -1,12 +1,28 @@ # Extension: torchcomms._comms_xccl -file(GLOB TORCHCOMMS_XCCL_SOURCES "comms/torchcomms/xccl/*.cpp") -file(GLOB TORCHCOMMS_XPU_API_SOURCE "comms/torchcomms/device/XpuApi.cpp") +# Check if CCL_ROOT is set +if(NOT DEFINED ENV{CCL_ROOT}) + message(WARNING "oneCCL environment not found; skipping XCCL backend compilation.") + return() +endif() + +# Set XCCL paths set(XCCL_INCLUDE "$ENV{CCL_ROOT}/include") set(XCCL_SHARED_LIB "$ENV{CCL_ROOT}/lib/libccl.so.2") -include(FindPackageHandleStandardArgs) +# Validate oneCCL installation +if(NOT EXISTS "${XCCL_INCLUDE}" OR NOT EXISTS "${XCCL_SHARED_LIB}") + message(WARNING "Invalid oneCCL path. Skipping XCCL backend compilation.") + return() +endif() + +message(STATUS "XCCL include path : ${XCCL_INCLUDE}") +message(STATUS "XCCL library : ${XCCL_SHARED_LIB}") file(GLOB TORCHCOMMS_XCCL_SOURCES "comms/torchcomms/xccl/*.cpp") file(GLOB TORCHCOMMS_XPU_API_SOURCE "comms/torchcomms/device/XpuApi.cpp") + +include(FindPackageHandleStandardArgs) add_library(torchcomms_comms_xccl MODULE ${TORCHCOMMS_XCCL_SOURCES}