diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4c9ca1cf..5a2dc42f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -11,11 +11,13 @@ option(USE_NCCLX "Whether to build NCCLX or not" ON)
 option(USE_GLOO "Whether to build Gloo or not" ON)
 option(USE_RCCL "Whether to build RCCL or not" OFF)
 option(USE_RCCLX "Whether to build RCCLX or not" OFF)
+option(USE_XCCL "Whether to build XCCL or not" OFF)
 message(STATUS " USE_NCCL  : ${USE_NCCL}")
 message(STATUS " USE_NCCLX : ${USE_NCCLX}")
 message(STATUS " USE_GLOO  : ${USE_GLOO}")
 message(STATUS " USE_RCCL  : ${USE_RCCL}")
 message(STATUS " USE_RCCLX : ${USE_RCCLX}")
+message(STATUS " USE_XCCL  : ${USE_XCCL}")
 
 # Find Python and PyTorch
 find_package(Python3 COMPONENTS Interpreter Development.Module REQUIRED)
 find_package(Torch REQUIRED)
@@ -112,6 +114,9 @@ if (USE_RCCL)
 endif()
 if (USE_RCCLX)
   include(comms/torchcomms/rcclx/CMakeLists.txt)
 endif()
+if (USE_XCCL)
+  include(comms/torchcomms/xccl/CMakeLists.txt)
+endif()
 
 # Install targets to Python package structure
diff --git a/comms/torchcomms/device/XpuApi.cpp b/comms/torchcomms/device/XpuApi.cpp
new file mode 100644
index 00000000..cdccd4ae
--- /dev/null
+++ b/comms/torchcomms/device/XpuApi.cpp
@@ -0,0 +1,315 @@
+#include "comms/torchcomms/device/XpuApi.hpp"
+#include <ATen/xpu/XPUEvent.h>
+#include <c10/xpu/XPUFunctions.h>
+#include <c10/xpu/XPUStream.h>
+#include <cstring>
+#include <sycl/sycl.hpp>
+
+namespace torch {
+namespace comms {
+
+xpu_result_t DefaultXpuApi::setDevice(int device) {
+  try {
+    ::c10::xpu::set_device(device);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::getDeviceProperties(xpuDeviceProp* prop, int device) {
+  if (!prop) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  try {
+    sycl::device sycl_device = ::c10::xpu::get_raw_device(device);
+
+    // Get device name
+    std::string device_name = sycl_device.get_info<sycl::info::device::name>();
+    strncpy(prop->name, device_name.c_str(), 255);
+    prop->name[255] = '\0';
+
+    // Get memory info
+    prop->totalGlobalMem =
+        sycl_device.get_info<sycl::info::device::global_mem_size>();
+
+    // Get compute capabilities
+    auto max_work_group_size =
+        sycl_device.get_info<sycl::info::device::max_work_group_size>();
+    auto max_work_item_sizes =
+        sycl_device.get_info<sycl::info::device::max_work_item_sizes<3>>();
+    auto max_compute_units =
+        sycl_device.get_info<sycl::info::device::max_compute_units>();
+
+    prop->multiProcessorCount = max_compute_units;
+    prop->maxThreadsPerBlock = max_work_group_size;
+    prop->maxThreadsDim[0] = max_work_item_sizes[0];
+    prop->maxThreadsDim[1] = max_work_item_sizes[1];
+    prop->maxThreadsDim[2] = max_work_item_sizes[2];
+
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::memGetInfo(size_t* free, size_t* total) {
+  if (!free || !total) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  try {
+    int device = ::c10::xpu::current_device();
+    sycl::device& sycl_device = ::c10::xpu::get_raw_device(device);
+
+    *total = sycl_device.get_info<sycl::info::device::global_mem_size>();
+    *free = *total; // SYCL doesn't provide a free-memory query
+
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::getDeviceCount(int* count) {
+  if (!count) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  try {
+    *count = ::c10::xpu::device_count();
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::streamCreateWithPriority(
+    xpuStream_t& stream,
+    unsigned int flags,
+    int priority) {
+  try {
+    // Map priority: priority < 0 = high, priority >= 0 = normal
+    bool isHighPriority = (priority < 0);
+    stream = ::c10::xpu::getStreamFromPool(isHighPriority);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::streamDestroy(const xpuStream_t& stream) {
+  // Stream is managed by PyTorch, nothing to do
+  return XPU_SUCCESS;
+}
+
+xpu_result_t DefaultXpuApi::streamWaitEvent(
+    const xpuStream_t& stream,
+    xpuEvent_t& event,
+    unsigned int flags) {
+  try {
+    event.block(stream);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_HANDLE;
+  }
+}
+
+xpuStream_t DefaultXpuApi::getCurrentXPUStream(int device_index) {
+  return ::c10::xpu::getCurrentXPUStream(device_index);
+}
+
+xpu_result_t DefaultXpuApi::streamSynchronize(const xpuStream_t& stream) {
+  try {
+    stream.queue().wait_and_throw();
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_HANDLE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::streamIsCapturing(
+    const xpuStream_t& stream,
+    xpuStreamCaptureStatus* pCaptureStatus) {
+  if (!pCaptureStatus) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  // XPU/SYCL doesn't support stream capture
+  *pCaptureStatus = xpuStreamCaptureStatusNone;
+  return XPU_SUCCESS;
+}
+
+xpu_result_t DefaultXpuApi::streamGetCaptureInfo(
+    const xpuStream_t& stream,
+    xpuStreamCaptureStatus* pCaptureStatus,
+    unsigned long long* pId) {
+  if (!pCaptureStatus) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  *pCaptureStatus = xpuStreamCaptureStatusNone;
+  if (pId) {
+    *pId = 0;
+  }
+  return XPU_SUCCESS;
+}
+
+xpu_result_t DefaultXpuApi::malloc(void** devPtr, size_t size) {
+  if (!devPtr) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  if (size == 0) {
+    *devPtr = nullptr;
+    return XPU_SUCCESS;
+  }
+
+  try {
+    // Use SYCL's malloc_device
+    sycl::context& ctx = ::c10::xpu::get_device_context();
+    int device = ::c10::xpu::current_device();
+    sycl::device& dev = ::c10::xpu::get_raw_device(device);
+
+    *devPtr = sycl::malloc_device(size, dev, ctx);
+
+    if (!*devPtr) {
+      return XPU_ERROR_OUT_OF_MEMORY;
+    }
+
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_OUT_OF_MEMORY;
+  }
+}
+
+xpu_result_t DefaultXpuApi::free(void* devPtr) {
+  if (!devPtr) {
+    return XPU_SUCCESS;
+  }
+
+  try {
+    sycl::context& ctx = ::c10::xpu::get_device_context();
+    sycl::free(devPtr, ctx);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::memcpyAsync(
+    void* dst,
+    const void* src,
+    size_t count,
+    const xpuStream_t& stream) {
+  if (!dst || !src) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+
+  if (count == 0) {
+    return XPU_SUCCESS;
+  }
+
+  try {
+    stream.queue().memcpy(dst, src, count);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::eventCreate(xpuEvent_t& event) {
+  try {
+    event = ::at::xpu::XPUEvent(false); // No timing
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::eventCreateWithFlags(
+    xpuEvent_t& event,
+    unsigned int flags) {
+  try {
+    bool enable_timing = (flags & 0x1) != 0;
+    event = ::at::xpu::XPUEvent(enable_timing);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_VALUE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::eventDestroy(const xpuEvent_t& event) {
+  // Event is RAII, nothing to do
+  return XPU_SUCCESS;
+}
+
+xpu_result_t DefaultXpuApi::eventRecord(xpuEvent_t& event, const xpuStream_t& stream) {
+  try {
+    event.record(stream);
+    return XPU_SUCCESS;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_HANDLE;
+  }
+}
+
+xpu_result_t DefaultXpuApi::eventQuery(const xpuEvent_t& event) {
+  try {
+    bool is_complete = event.query();
+    return is_complete ? XPU_SUCCESS : XPU_ERROR_NOT_READY;
+  } catch (const std::exception& e) {
+    return XPU_ERROR_INVALID_HANDLE;
+  }
+}
+
+// Graph Operations (Unsupported)
+xpu_result_t DefaultXpuApi::userObjectCreate(
+    xpuUserObject_t* object_out,
+    void* ptr,
+    xpuHostFn_t destroy,
+    unsigned int initialRefcount,
+    unsigned int flags) {
+  // XPU/SYCL doesn't support user objects
+  return XPU_ERROR_UNSUPPORTED;
+}
+
+xpu_result_t DefaultXpuApi::graphRetainUserObject(
+    xpuGraph_t graph,
+    xpuUserObject_t object,
+    unsigned int count,
+    unsigned int flags) {
+  // Currently, XPU/SYCL doesn't support graphs
+  return XPU_ERROR_UNSUPPORTED;
+}
+
+xpu_result_t DefaultXpuApi::streamGetCaptureInfo_v2(
+    const xpuStream_t& stream,
+    xpuStreamCaptureStatus* captureStatus_out,
+    unsigned long long* id_out,
+    xpuGraph_t* graph_out,
+    const xpuGraphNode_t** dependencies_out,
+    size_t* numDependencies_out) {
+  // Currently, XPU/SYCL doesn't support graphs
+  return XPU_ERROR_UNSUPPORTED;
+}
+
+// Error Handling
+const char* DefaultXpuApi::getErrorString(xpu_result_t error) {
+  switch (error) {
+    case XPU_SUCCESS:
+      return "success";
+    case XPU_ERROR_INVALID_VALUE:
+      return "invalid value";
+    case XPU_ERROR_NOT_READY:
+      return "not ready";
+    case XPU_ERROR_INVALID_HANDLE:
+      return "invalid handle";
+    case XPU_ERROR_OUT_OF_MEMORY:
+      return "out of memory";
+    case XPU_ERROR_UNSUPPORTED:
+      return "unsupported feature";
+    default:
+      return "unknown error";
+  }
+}
+
+} // namespace comms
+} // namespace torch
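For reference, a minimal sketch of how this wrapper is meant to be consumed; the caller below is hypothetical, but `DefaultXpuApi` and `XPU_CHECK` are the ones defined in this patch:

    #include "comms/torchcomms/device/XpuApi.hpp"

    using torch::comms::DefaultXpuApi;

    void probeDevices() {
      DefaultXpuApi api;
      int count = 0;
      // The API reports failures through xpu_result_t codes rather than
      // exceptions; XPU_CHECK turns a non-success code into std::runtime_error.
      XPU_CHECK(&api, api.getDeviceCount(&count), "Failed to query device count");
      XPU_CHECK(&api, api.setDevice(0), "Failed to select device 0");
    }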
diff --git a/comms/torchcomms/device/XpuApi.hpp b/comms/torchcomms/device/XpuApi.hpp
new file mode 100644
index 00000000..b07d539d
--- /dev/null
+++ b/comms/torchcomms/device/XpuApi.hpp
@@ -0,0 +1,202 @@
+#pragma once
+
+#include <ATen/xpu/XPUEvent.h>
+#include <c10/xpu/XPUStream.h>
+#include <cstdint>
+#include <sstream>
+#include <stdexcept>
+
+namespace torch {
+namespace comms {
+
+using xpuStream_t = ::c10::xpu::XPUStream;
+using xpuEvent_t = ::at::xpu::XPUEvent;
+
+struct xpuDeviceProp {
+  char name[256];
+  size_t totalGlobalMem;
+  int multiProcessorCount;
+  int maxThreadsPerBlock;
+  int maxThreadsDim[3];
+};
+
+// Graph-related types (placeholder - unsupported in XPU)
+using xpuGraph_t = void*;
+using xpuGraphNode_t = void*;
+using xpuUserObject_t = void*;
+using xpuHostFn_t = void(*)(void*);
+
+// Stream capture status (not supported in XPU)
+enum xpuStreamCaptureStatus {
+  xpuStreamCaptureStatusNone = 0,
+};
+
+// Error code type
+using xpu_result_t = int32_t;
+constexpr xpu_result_t XPU_SUCCESS = 0;
+constexpr xpu_result_t XPU_ERROR_INVALID_VALUE = 1;
+constexpr xpu_result_t XPU_ERROR_NOT_READY = 2;
+constexpr xpu_result_t XPU_ERROR_INVALID_HANDLE = 3;
+constexpr xpu_result_t XPU_ERROR_OUT_OF_MEMORY = 4;
+constexpr xpu_result_t XPU_ERROR_UNSUPPORTED = 5;
+
+#define XPU_CHECK(xpu_api, call, err_str)                                  \
+  do {                                                                     \
+    xpu_result_t status = call;                                            \
+    if (status != XPU_SUCCESS) {                                           \
+      std::stringstream ss;                                                \
+      ss << err_str << ": " << xpu_api->getErrorString(status) << " at "   \
+         << __FILE__ << ":" << __LINE__;                                   \
+      throw std::runtime_error(ss.str());                                  \
+    }                                                                      \
+  } while (0)
+
+/**
+ * Abstract interface for XPU API operations.
+ * This allows for dependency injection and testing by providing
+ * a way to override XPU API calls.
+ */
+class XpuApi {
+ public:
+  virtual ~XpuApi() = default;
+
+  // Device management
+  virtual xpu_result_t setDevice(int device) = 0;
+  virtual xpu_result_t getDeviceProperties(xpuDeviceProp* prop, int device) = 0;
+  virtual xpu_result_t memGetInfo(size_t* free, size_t* total) = 0;
+  virtual xpu_result_t getDeviceCount(int* count) = 0;
+
+  // Stream management
+  virtual xpu_result_t streamCreateWithPriority(
+      xpuStream_t& stream,
+      unsigned int flags,
+      int priority) = 0;
+  virtual xpu_result_t streamDestroy(const xpuStream_t& stream) = 0;
+  virtual xpu_result_t streamWaitEvent(
+      const xpuStream_t& stream,
+      xpuEvent_t& event,
+      unsigned int flags) = 0;
+  virtual xpuStream_t getCurrentXPUStream(int device_index) = 0;
+  virtual xpu_result_t streamSynchronize(const xpuStream_t& stream) = 0;
+  virtual xpu_result_t streamIsCapturing(
+      const xpuStream_t& stream,
+      xpuStreamCaptureStatus* pCaptureStatus) = 0;
+  virtual xpu_result_t streamGetCaptureInfo(
+      const xpuStream_t& stream,
+      xpuStreamCaptureStatus* pCaptureStatus,
+      unsigned long long* pId) = 0;
+
+  // Memory management
+  virtual xpu_result_t malloc(void** devPtr, size_t size) = 0;
+  virtual xpu_result_t free(void* devPtr) = 0;
+  virtual xpu_result_t memcpyAsync(
+      void* dst,
+      const void* src,
+      size_t count,
+      const xpuStream_t& stream) = 0;
+
+  // Event management
+  virtual xpu_result_t eventCreate(xpuEvent_t& event) = 0;
+  virtual xpu_result_t eventCreateWithFlags(
+      xpuEvent_t& event,
+      unsigned int flags) = 0;
+  virtual xpu_result_t eventDestroy(const xpuEvent_t& event) = 0;
+  virtual xpu_result_t eventRecord(xpuEvent_t& event, const xpuStream_t& stream) = 0;
+  virtual xpu_result_t eventQuery(const xpuEvent_t& event) = 0;
+
+  // Graph operations (unsupported, kept for API compatibility)
+  virtual xpu_result_t userObjectCreate(
+      xpuUserObject_t* object_out,
+      void* ptr,
+      xpuHostFn_t destroy,
+      unsigned int initialRefcount,
+      unsigned int flags) = 0;
+  virtual xpu_result_t graphRetainUserObject(
+      xpuGraph_t graph,
+      xpuUserObject_t object,
+      unsigned int count,
+      unsigned int flags) = 0;
+  virtual xpu_result_t streamGetCaptureInfo_v2(
+      const xpuStream_t& stream,
+      xpuStreamCaptureStatus* captureStatus_out,
+      unsigned long long* id_out,
+      xpuGraph_t* graph_out,
+      const xpuGraphNode_t** dependencies_out,
+      size_t* numDependencies_out) = 0;
+
+  // Error handling
+  virtual const char* getErrorString(xpu_result_t error) = 0;
+};
+
+class DefaultXpuApi : public XpuApi {
+ public:
+  ~DefaultXpuApi() override = default;
+
+  // Device management
+  xpu_result_t setDevice(int device) override;
+  xpu_result_t getDeviceProperties(xpuDeviceProp* prop, int device) override;
+  xpu_result_t memGetInfo(size_t* free, size_t* total) override;
+  xpu_result_t getDeviceCount(int* count) override;
+
+  // Stream management
+  xpu_result_t streamCreateWithPriority(
+      xpuStream_t& stream,
+      unsigned int flags,
+      int priority) override;
+  xpu_result_t streamDestroy(const xpuStream_t& stream) override;
+  xpu_result_t streamWaitEvent(
+      const xpuStream_t& stream,
+      xpuEvent_t& event,
+      unsigned int flags) override;
+  xpuStream_t getCurrentXPUStream(int device_index) override;
+  xpu_result_t streamSynchronize(const xpuStream_t& stream) override;
+  xpu_result_t streamIsCapturing(
+      const xpuStream_t& stream,
+      xpuStreamCaptureStatus* pCaptureStatus) override;
+  xpu_result_t streamGetCaptureInfo(
+      const xpuStream_t& stream,
+      xpuStreamCaptureStatus* pCaptureStatus,
+      unsigned long long* pId) override;
+
+  // Memory management
+  xpu_result_t malloc(void** devPtr, size_t size) override;
+  xpu_result_t free(void* devPtr) override;
+  xpu_result_t memcpyAsync(
+      void* dst,
+      const void* src,
+      size_t count,
+      const xpuStream_t& stream) override;
+
+  // Event management
+  xpu_result_t eventCreate(xpuEvent_t& event) override;
+  xpu_result_t eventCreateWithFlags(xpuEvent_t& event, unsigned int flags) override;
+  xpu_result_t eventDestroy(const xpuEvent_t& event) override;
+  xpu_result_t eventRecord(xpuEvent_t& event, const xpuStream_t& stream) override;
+  xpu_result_t eventQuery(const xpuEvent_t& event) override;
+
+  // Graph operations (unsupported)
+  xpu_result_t userObjectCreate(
+      xpuUserObject_t* object_out,
+      void* ptr,
+      xpuHostFn_t destroy,
+      unsigned int initialRefcount,
+      unsigned int flags) override;
+  xpu_result_t graphRetainUserObject(
+      xpuGraph_t graph,
+      xpuUserObject_t object,
+      unsigned int count,
+      unsigned int flags) override;
+  xpu_result_t streamGetCaptureInfo_v2(
+      const xpuStream_t& stream,
+      xpuStreamCaptureStatus* captureStatus_out,
+      unsigned long long* id_out,
+      xpuGraph_t* graph_out,
+      const xpuGraphNode_t** dependencies_out,
+      size_t* numDependencies_out) override;
+
+  // Error handling
+  const char* getErrorString(xpu_result_t error) override;
+};
+
+} // namespace comms
+} // namespace torch
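Because every entry point is virtual, a test can stub just the calls it cares about. A hand-rolled sketch of such a double (the class name and the forced two-device answer are invented for illustration):

    class FakeXpuApi : public torch::comms::DefaultXpuApi {
     public:
      // Pretend exactly two devices exist, regardless of the host machine.
      torch::comms::xpu_result_t getDeviceCount(int* count) override {
        if (!count) {
          return torch::comms::XPU_ERROR_INVALID_VALUE;
        }
        *count = 2;
        return torch::comms::XPU_SUCCESS;
      }
    };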
diff --git a/comms/torchcomms/xccl/CMakeLists.txt b/comms/torchcomms/xccl/CMakeLists.txt
new file mode 100644
index 00000000..b8726888
--- /dev/null
+++ b/comms/torchcomms/xccl/CMakeLists.txt
@@ -0,0 +1,59 @@
+# Extension: torchcomms._comms_xccl
+
+# Check if CCL_ROOT is set
+if(NOT DEFINED ENV{CCL_ROOT})
+  message(WARNING "oneCCL environment not found; skipping XCCL backend compilation.")
+  return()
+endif()
+
+# Set XCCL paths
+set(XCCL_INCLUDE "$ENV{CCL_ROOT}/include")
+set(XCCL_SHARED_LIB "$ENV{CCL_ROOT}/lib/libccl.so.2")
+
+# Validate oneCCL installation
+if(NOT EXISTS "${XCCL_INCLUDE}" OR NOT EXISTS "${XCCL_SHARED_LIB}")
+  message(WARNING "Invalid oneCCL path; skipping XCCL backend compilation.")
+  return()
+endif()
+
+message(STATUS "XCCL include path : ${XCCL_INCLUDE}")
+message(STATUS "XCCL library      : ${XCCL_SHARED_LIB}")
+
+file(GLOB TORCHCOMMS_XCCL_SOURCES "comms/torchcomms/xccl/*.cpp")
+file(GLOB TORCHCOMMS_XPU_API_SOURCE "comms/torchcomms/device/XpuApi.cpp")
+
+include(FindPackageHandleStandardArgs)
+
+add_library(torchcomms_comms_xccl MODULE
+  ${TORCHCOMMS_XCCL_SOURCES}
+  ${TORCHCOMMS_XPU_API_SOURCE}
+)
+set_target_properties(torchcomms_comms_xccl PROPERTIES
+  PREFIX ""
+  OUTPUT_NAME "_comms_xccl"
+  SUFFIX ".${Python3_SOABI}.so"
+  LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/comms/torchcomms"
+  BUILD_RPATH "$ORIGIN"
+  INSTALL_RPATH "$ORIGIN"
+)
+target_include_directories(torchcomms_comms_xccl PRIVATE
+  ${ROOT}
+  ${XCCL_INCLUDE}
+  ${CONDA_INCLUDE}
+  ${Python3_INCLUDE_DIRS}
+)
+target_compile_features(torchcomms_comms_xccl PRIVATE cxx_std_20)
+target_link_directories(torchcomms_comms_xccl PRIVATE ${CONDA_LIB})
+target_link_libraries(torchcomms_comms_xccl PRIVATE
+  ${TORCH_LIBRARIES}
+  ${TORCH_PYTHON_LIB}
+  torchcomms
+)
+
+target_link_libraries(torchcomms_comms_xccl PRIVATE
+  ${XCCL_SHARED_LIB}
+)
+
+install(TARGETS torchcomms_comms_xccl
+  LIBRARY DESTINATION .
+)
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.cpp b/comms/torchcomms/xccl/TorchCommXCCL.cpp
new file mode 100644
index 00000000..b0c7bf82
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCL.cpp
@@ -0,0 +1,459 @@
+#include "comms/torchcomms/xccl/TorchCommXCCL.hpp"
+
+#include "comms/torchcomms/TorchCommFactory.hpp"
+#include "comms/torchcomms/TorchCommLogging.hpp"
+#include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp"
+#include <chrono>
+#include <memory>
+#include <stdexcept>
+#include <string>
+
+namespace torch {
+namespace comms {
+
+onecclResult_t XCCLException::getResult() const { return result_; }
+
+TorchCommXCCL::TorchCommXCCL()
+    : xccl_comm_{nullptr}, device_(at::kXPU),
+      init_state_(InitializationState::UNINITIALIZED), shutdown_(false) {}
+
+TorchCommXCCL::TorchCommXCCL(const onecclComm_t xccl_comm)
+    : xccl_comm_(xccl_comm), device_(at::kXPU),
+      init_state_(InitializationState::UNINITIALIZED), shutdown_(false) {}
+
+TorchCommXCCL::~TorchCommXCCL() {
+  if (init_state_ == InitializationState::INITIALIZED) {
+    TC_LOG(ERROR) << "TorchCommXCCL was not finalized before destruction";
+
+    // If finalize was not called, we need to clean up the timeout thread
+    if (timeout_thread_.joinable()) {
+      shutdown_.store(true);
+      timeout_thread_.join();
+    }
+  }
+}
+
+void TorchCommXCCL::init(at::Device device, const std::string &name,
+                         const CommOptions &options) {
+  // Initialize private members
+  device_ = device;
+  name_ = name;
+  options_ = options;
+
+  // Only initialize once
+  if (init_state_ == InitializationState::INITIALIZED) {
+    throw std::runtime_error("TorchCommXCCL already initialized");
+  } else if (init_state_ == InitializationState::FINALIZED) {
+    throw std::runtime_error("TorchCommXCCL already finalized");
+  }
+  init_state_ = InitializationState::INITIALIZED;
+
+  // Initialize default XCCL API implementation if not already set
+  if (!xccl_api_) {
+    xccl_api_ = std::make_unique<DefaultXcclApi>();
+  }
+
+  // Initialize default XPU API implementation if not already set
+  if (!xpu_api_) {
+    xpu_api_ = std::make_unique<DefaultXpuApi>();
+  }
+
+  if (device_.index() == -1 || xccl_comm_ == nullptr) {
+    auto bootstrap = new TorchCommXCCLBootstrap(
+        options_.store, device_, xccl_api_, xpu_api_, options_.timeout);
+    device_ = bootstrap->getDevice();
+
+    if (xccl_comm_ == nullptr) {
+      xccl_comm_ = bootstrap->createXcclComm(name, options);
+    }
+
+    delete bootstrap;
+  }
+
+  // Set XPU device and verify it's accessible
+  XPU_CHECK(xpu_api_, xpu_api_->setDevice(device_.index()),
+            "Failed to set XPU device to " + std::to_string(device_.index()));
+
+  // Read hints and store them
+  for (auto const &[key, val] : options_.hints) {
+    if (key.starts_with("torchcomm::xccl::")) {
+      if (key == "torchcomm::xccl::high_priority_stream") {
+        high_priority_stream_ = string_to_bool(val);
+      } else {
+        throw std::runtime_error("Unrecognized hint " + key);
+      }
+    } else {
+      // Ignore keys that do not start with "torchcomm::xccl::"
+    }
+  }
+
+  // Create internal stream
+  int stream_priority = 0;
+
+  // Check for high priority stream hint
+  if (high_priority_stream_) {
+    stream_priority = -1;
+  }
+
+  // Initialize internal stream
+  xpuStream_t temp_stream = xpu_api_->getCurrentXPUStream(device_.index());
+  XPU_CHECK(xpu_api_,
+            xpu_api_->streamCreateWithPriority(temp_stream, /*flags=*/0,
+                                               stream_priority),
+            "Failed to create internal XPU stream on device " +
+                std::to_string(device_.index()));
+  internal_stream_ = std::move(temp_stream);
+
+  // Create dependency event for stream synchronization
+  xpuEvent_t temp_event(/*enable_timing=*/false);
+  XPU_CHECK(xpu_api_, xpu_api_->eventCreateWithFlags(temp_event, /*flags=*/0),
+            "Failed to create dependency event on device " +
+                std::to_string(device_.index()));
+  dependency_event_ = std::move(temp_event);
+
+  // Allocate XPU buffer for barrier operations
+  XPU_CHECK(xpu_api_, xpu_api_->malloc(&barrier_buffer_, sizeof(float)),
+            "Failed to allocate barrier buffer");
+
+  if (options_.hints.contains("torchcomm::xccl::max_event_pool_size")) {
+    max_event_pool_size_ =
+        std::stoull(options_.hints.at("torchcomm::xccl::max_event_pool_size"));
+  } else {
+    max_event_pool_size_ = kMaxEventPoolSize;
+  }
+
+  // Give up our internal reference to the store object here. The caller
+  // would still need to keep a reference to the store object till the init
+  // call returns, at which point the XCCL communicator would already be
+  // created.
+  if (options_.store) {
+    options_.store.reset();
+  }
+
+  onecclResult_t xcclErr;
+  xcclErr = xccl_api_->commUserRank(xccl_comm_, &rank_);
+  if (xcclErr != onecclSuccess) {
+    throw std::runtime_error("XCCL User Rank failed");
+  }
+
+  tryTorchCommLoggingInit("torchcomm");
+
+  xcclErr = xccl_api_->commCount(xccl_comm_, &comm_size_);
+  if (xcclErr != onecclSuccess) {
+    throw std::runtime_error("XCCL Count failed");
+  }
+
+  tracing_ = std::make_shared<TorchCommTracing>(name, comm_size_, rank_);
+  tracing_->recordEvent("init");
+
+  // Start timeout watchdog thread
+  timeout_thread_ = std::thread(&TorchCommXCCL::timeoutWatchdog, this);
+}
+
+void TorchCommXCCL::finalize() {
+  if (init_state_ == InitializationState::UNINITIALIZED) {
+    throw std::runtime_error("TorchCommXCCL not initialized");
+  } else if (init_state_ == InitializationState::FINALIZED) {
+    throw std::runtime_error("TorchCommXCCL already finalized");
+  }
+  init_state_ = InitializationState::FINALIZED;
+
+  // Signal shutdown to timeout watchdog
+  shutdown_ = true;
+
+  // Wake up the timeout watchdog thread
+  {
+    std::lock_guard lock(timeout_mutex_);
+    timeout_cv_.notify_all();
+  }
+
+  // Wait for timeout thread to finish
+  if (timeout_thread_.joinable()) {
+    timeout_thread_.join();
+  }
+
+  // Wait for all pending work objects to complete and get final status
+  auto work_status = workq_.finalize();
+
+  if (work_status == TorchWorkXCCL::WorkStatus::NOT_STARTED ||
+      work_status == TorchWorkXCCL::WorkStatus::INPROGRESS) {
+    throw std::runtime_error(
+        "WorkQ finalize returned in progress or not started state");
+  }
+
+  // Update comm_state_ based on the work status
+  if (work_status == TorchWorkXCCL::WorkStatus::TIMEDOUT) {
+    comm_state_ = CommState::TIMEOUT;
+    abortXcclComm();
+    throw std::runtime_error("Work timed out during finalize");
+  } else if (work_status == TorchWorkXCCL::WorkStatus::ERROR) {
+    comm_state_ = CommState::ERROR;
+    onecclResult_t asyncErr;
+    xccl_api_->commGetAsyncError(xccl_comm_, &asyncErr);
+    XCCLException xcclException(*xccl_api_, "XCCL Async Error", asyncErr);
+    abortXcclComm();
+    throw xcclException;
+  }
+
+  // Clean up event pool
+  {
+    std::lock_guard lock(event_pool_mutex_);
+    while (!event_pool_.empty()) {
+      xpuEvent_t event = std::move(event_pool_.front());
+      event_pool_.pop();
+      XPU_CHECK(xpu_api_, xpu_api_->eventDestroy(event),
+                "Failed to destroy event");
+    }
+  }
+
+  // Free barrier buffer.
+  // TODO: handle errors on xpu free and stream destroy
+  if (barrier_buffer_) {
+    XPU_CHECK(xpu_api_, xpu_api_->free(barrier_buffer_),
+              "Failed to free barrier buffer");
+    barrier_buffer_ = nullptr;
+  }
+
+  // Destroy dependency event
+  if (dependency_event_.has_value()) {
+    XPU_CHECK(xpu_api_, xpu_api_->eventDestroy(dependency_event_.value()),
+              "Failed to destroy dependency event");
+    dependency_event_.reset();
+  }
+
+  // Destroy internal stream
+  if (internal_stream_.has_value()) {
+    XPU_CHECK(xpu_api_, xpu_api_->streamDestroy(internal_stream_.value()),
+              "Failed to destroy internal stream");
+    internal_stream_.reset();
+  }
+
+  // Destroy XCCL communicator
+  // TODO: should probably not call this after calling abort.
+  if (xccl_comm_) {
+    xccl_api_->commDestroy(xccl_comm_);
+    xccl_comm_ = nullptr;
+  }
+}
+
+void TorchCommXCCL::abortXcclComm() {
+  if (xccl_comm_) {
+    xccl_api_->commAbort(xccl_comm_);
+    xccl_comm_ = nullptr;
+  }
+  if (options_.abort_process_on_timeout_or_error) {
+    TC_LOG(ERROR) << "Aborting process due to timeout";
+    abort();
+  }
+}
+
+int TorchCommXCCL::getRank() const {
+  checkInitialized();
+
+  int rank;
+  onecclResult_t xcclErr = xccl_api_->commUserRank(xccl_comm_, &rank);
+  if (xcclErr != onecclSuccess) {
+    throw XCCLException(*xccl_api_, "XCCL User Rank failed", xcclErr);
+  }
+  return rank;
+}
+
+int TorchCommXCCL::getSize() const {
+  checkInitialized();
+
+  int comm_size;
+  onecclResult_t xcclErr = xccl_api_->commCount(xccl_comm_, &comm_size);
+  if (xcclErr != onecclSuccess) {
+    throw XCCLException(*xccl_api_, "XCCL Count failed", xcclErr);
+  }
+  return comm_size;
+}
+
+std::string_view TorchCommXCCL::getBackendName() const { return kBackendName; }
+
+std::string_view TorchCommXCCL::getCommName() const { return name_; }
+
+static inline std::chrono::milliseconds
+getOperationTimeout(std::chrono::milliseconds timeout,
+                    std::chrono::milliseconds default_timeout) {
+  // If timeout is kNoTimeout (0ms), use the default timeout from options
+  if (timeout == kNoTimeout) {
+    return default_timeout;
+  }
+  return timeout;
+}
+
+// Point-to-Point Operations
+c10::intrusive_ptr<TorchWork> TorchCommXCCL::send(const at::Tensor &tensor,
+                                                  int dst, bool async_op,
+                                                  const SendOptions &options) {
+  throw std::runtime_error("XCCL send is not yet supported; it will be added later");
+}
+
+c10::intrusive_ptr<TorchWork> TorchCommXCCL::recv(at::Tensor &tensor, int src,
+                                                  bool async_op,
+                                                  const RecvOptions &options) {
+  throw std::runtime_error("XCCL recv is not yet supported; it will be added later");
+}
+
+// Batch P2P Operations
+c10::intrusive_ptr<TorchWork>
+TorchCommXCCL::batch_op_issue(const std::vector<BatchP2POp> &ops,
+                              bool async_op, const BatchP2POptions &options) {
+  throw std::runtime_error("XCCL batch_op_issue is not yet supported; it will be added later");
+}
+
+// Collective Operations
+c10::intrusive_ptr<TorchWork>
+TorchCommXCCL::broadcast(at::Tensor &tensor, int root, bool async_op,
+                         const BroadcastOptions &options) {
+  throw std::runtime_error("XCCL broadcast is not yet supported; it will be added later");
+}
+
+c10::intrusive_ptr<TorchWork>
+TorchCommXCCL::all_reduce(at::Tensor &tensor, const ReduceOp &op, bool async_op,
+                          const AllReduceOptions &options) {
+  checkInitialized();
+  checkAndAbortIfTimedOutOrError();
+  ensureTensorContiguous(tensor);
+
+  tracing_->recordEventWithInputOutput("all_reduce", rank_, {tensor}, {tensor});
+
+  xpuStream_t stream = getOperationStream(async_op);
+  auto work = createWork(
+      stream, getOperationTimeout(options.timeout, options_.timeout), {tensor});
+
+  work->recordStart();
+
+  const auto dataType = getXcclDataType(tensor);
+  onecclResult_t result = xccl_api_->allReduce(
+      tensor.data_ptr(),
+      tensor.data_ptr(), // In-place operation
+      tensor.numel(), dataType, getXcclReduceOp(op, xccl_comm_, dataType),
+      xccl_comm_, stream);
+
+  if (result != onecclSuccess) {
+    throw XCCLException(*xccl_api_, "XCCL AllReduce failed", result);
+  }
+
+  work->recordEnd();
+
+  enqueueWork(work, stream);
+
+  return work;
+}
supported now and will be added later"); +} + +c10::intrusive_ptr +TorchCommXCCL::gather(const std::vector &output_tensor_list, + const at::Tensor &input_tensor, int root, bool async_op, + const GatherOptions &options) { + throw std::runtime_error("XCCL gather is not supported now and will be added later"); +} + +std::shared_ptr +TorchCommXCCL::split(const std::vector &ranks, const std::string &name, + const CommOptions &options) { + throw std::runtime_error("XCCL split is not supported now and will be added later"); +} + +XCCLException::XCCLException(XcclApi &xccl_api, const std::string &message, + onecclResult_t result) + : message_(message + ": " + xccl_api.getErrorString(result)), + result_(result) {} + +const char *XCCLException::what() const noexcept { return message_.c_str(); } + +} // namespace comms +} // namespace torch + +namespace { +class XCCLRegistration { +public: + XCCLRegistration() { + torch::comms::TorchCommFactory::get().register_backend("xccl", []() { + return std::make_shared(); + }); + } +}; + +static XCCLRegistration registration{}; +} // namespace diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp new file mode 100644 index 00000000..a86a0d59 --- /dev/null +++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp @@ -0,0 +1,264 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include // @manual=//caffe2:torch-cpp + +#include "comms/torchcomms/TorchComm.hpp" +#include "comms/torchcomms/TorchCommBackend.hpp" +#include "comms/torchcomms/TorchCommBatch.hpp" +#include "comms/torchcomms/TorchCommTracing.hpp" +#include "comms/torchcomms/device/XpuApi.hpp" +#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp" +#include "comms/torchcomms/xccl/XcclApi.hpp" + +namespace torch { +namespace comms { + +constexpr size_t kMaxEventPoolSize = 1000; + +// Custom exception class for better error handling +class XCCLException : public std::exception { +public: + XCCLException(XcclApi &api, const std::string &message, + onecclResult_t result); + + const char *what() const noexcept override; + onecclResult_t getResult() const; + +private: + std::string message_; + onecclResult_t result_; +}; + +class TorchCommXCCL : public TorchCommBackend, + public std::enable_shared_from_this { +public: + static constexpr std::string_view kBackendName = "xccl"; + + TorchCommXCCL(); + ~TorchCommXCCL() override; + + // Delete copy and move operations + TorchCommXCCL(const TorchCommXCCL &) = delete; + TorchCommXCCL(TorchCommXCCL &&) = delete; + TorchCommXCCL &operator=(const TorchCommXCCL &) = delete; + TorchCommXCCL &operator=(TorchCommXCCL &&) = delete; + + void init(at::Device device, const std::string &name, + const CommOptions &options = {}) override; + void finalize() override; + int getRank() const override; + int getSize() const override; + std::string_view getBackendName() const override; + std::string_view getCommName() const override; + + // Point-to-Point Operations + c10::intrusive_ptr send(const at::Tensor &tensor, int dst, + bool async_op, + const SendOptions &options = {}) override; + c10::intrusive_ptr recv(at::Tensor &tensor, int src, bool async_op, + const RecvOptions &options = {}) override; + + // Batch P2P Operations + c10::intrusive_ptr + batch_op_issue(const std::vector &ops, bool async_op, + const BatchP2POptions &options = {}) override; + + // Collective Operations + c10::intrusive_ptr + broadcast(at::Tensor &tensor, int root, bool async_op, + const 
diff --git a/comms/torchcomms/xccl/TorchCommXCCL.hpp b/comms/torchcomms/xccl/TorchCommXCCL.hpp
new file mode 100644
index 00000000..a86a0d59
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCL.hpp
@@ -0,0 +1,264 @@
+#pragma once
+
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <exception>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <queue>
+#include <string>
+#include <string_view>
+#include <thread>
+
+#include <ATen/ATen.h>
+#include <torch/torch.h> // @manual=//caffe2:torch-cpp
+
+#include "comms/torchcomms/TorchComm.hpp"
+#include "comms/torchcomms/TorchCommBackend.hpp"
+#include "comms/torchcomms/TorchCommBatch.hpp"
+#include "comms/torchcomms/TorchCommTracing.hpp"
+#include "comms/torchcomms/device/XpuApi.hpp"
+#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp"
+#include "comms/torchcomms/xccl/XcclApi.hpp"
+
+namespace torch {
+namespace comms {
+
+constexpr size_t kMaxEventPoolSize = 1000;
+
+// Custom exception class for better error handling
+class XCCLException : public std::exception {
+public:
+  XCCLException(XcclApi &api, const std::string &message,
+                onecclResult_t result);
+
+  const char *what() const noexcept override;
+  onecclResult_t getResult() const;
+
+private:
+  std::string message_;
+  onecclResult_t result_;
+};
+
+class TorchCommXCCL : public TorchCommBackend,
+                      public std::enable_shared_from_this<TorchCommXCCL> {
+public:
+  static constexpr std::string_view kBackendName = "xccl";
+
+  TorchCommXCCL();
+  ~TorchCommXCCL() override;
+
+  // Delete copy and move operations
+  TorchCommXCCL(const TorchCommXCCL &) = delete;
+  TorchCommXCCL(TorchCommXCCL &&) = delete;
+  TorchCommXCCL &operator=(const TorchCommXCCL &) = delete;
+  TorchCommXCCL &operator=(TorchCommXCCL &&) = delete;
+
+  void init(at::Device device, const std::string &name,
+            const CommOptions &options = {}) override;
+  void finalize() override;
+  int getRank() const override;
+  int getSize() const override;
+  std::string_view getBackendName() const override;
+  std::string_view getCommName() const override;
+
+  // Point-to-Point Operations
+  c10::intrusive_ptr<TorchWork> send(const at::Tensor &tensor, int dst,
+                                     bool async_op,
+                                     const SendOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork> recv(at::Tensor &tensor, int src, bool async_op,
+                                     const RecvOptions &options = {}) override;
+
+  // Batch P2P Operations
+  c10::intrusive_ptr<TorchWork>
+  batch_op_issue(const std::vector<BatchP2POp> &ops, bool async_op,
+                 const BatchP2POptions &options = {}) override;
+
+  // Collective Operations
+  c10::intrusive_ptr<TorchWork>
+  broadcast(at::Tensor &tensor, int root, bool async_op,
+            const BroadcastOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_reduce(at::Tensor &tensor, const ReduceOp &op, bool async_op,
+             const AllReduceOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  reduce(const at::Tensor &tensor, int root, const ReduceOp &op, bool async_op,
+         const ReduceOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_gather(const std::vector<at::Tensor> &tensor_list,
+             const at::Tensor &tensor, bool async_op,
+             const AllGatherOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_gather_v(const std::vector<at::Tensor> &tensor_list,
+               const at::Tensor &tensor, bool async_op,
+               const AllGatherOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_gather_single(at::Tensor &output, const at::Tensor &input, bool async_op,
+                    const AllGatherSingleOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  reduce_scatter(at::Tensor &output, const std::vector<at::Tensor> &input_list,
+                 const ReduceOp &op, bool async_op,
+                 const ReduceScatterOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  reduce_scatter_v(at::Tensor &output, const std::vector<at::Tensor> &input_list,
+                   const ReduceOp &op, bool async_op,
+                   const ReduceScatterOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork> reduce_scatter_single(
+      at::Tensor &output, const at::Tensor &input, const ReduceOp &op, bool async_op,
+      const ReduceScatterSingleOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_to_all_single(at::Tensor &output, const at::Tensor &input, bool async_op,
+                    const AllToAllSingleOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_to_all_v_single(at::Tensor &output, const at::Tensor &input,
+                      const std::vector<uint64_t> &output_split_sizes,
+                      const std::vector<uint64_t> &input_split_sizes,
+                      bool async_op,
+                      const AllToAllvSingleOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  all_to_all(const std::vector<at::Tensor> &output_tensor_list,
+             const std::vector<at::Tensor> &input_tensor_list, bool async_op,
+             const AllToAllOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  barrier(bool async_op, const BarrierOptions &options = {}) override;
+
+  // Scatter and Gather Operations
+  c10::intrusive_ptr<TorchWork>
+  scatter(at::Tensor &output_tensor,
+          const std::vector<at::Tensor> &input_tensor_list, int root,
+          bool async_op, const ScatterOptions &options = {}) override;
+  c10::intrusive_ptr<TorchWork>
+  gather(const std::vector<at::Tensor> &output_tensor_list,
+         const at::Tensor &input_tensor, int root, bool async_op,
+         const GatherOptions &options = {}) override;
+
+  // Communicator Management
+  std::shared_ptr<TorchCommBackend>
+  split(const std::vector<int> &ranks, const std::string &name,
+        const CommOptions &options = {}) override;
+
+  // Friend access for TorchWorkXCCL
+  friend class TorchWorkXCCL;
+
+  // Getter for XPU API (for friend classes)
+  XpuApi *getXpuApi() const { return xpu_api_.get(); }
+
+  // Getter for XCCL API (for friend classes)
+  XcclApi *getXcclApi() const { return xccl_api_.get(); }
+
+  // Method to override the XCCL API implementation for testing
+  void setXcclApi(std::shared_ptr<XcclApi> api) {
+    xccl_api_ = std::move(api);
+  }
+
+  // Method to override the XPU API implementation for testing
+  void setXpuApi(std::shared_ptr<XpuApi> api) { xpu_api_ = std::move(api); }
+
+  const CommOptions &getOptions() const override { return options_; }
+
+  const at::Device &getDevice() const override { return device_; }
+
+protected:
+  // Event management for friend classes
+  xpuEvent_t getEvent();
+  void returnEvent(xpuEvent_t &&event);
+  void abortXcclComm();
+
+  enum class CommState {
+    NORMAL,
+    ERROR,
+    TIMEOUT,
+  };
+
+  std::atomic<CommState> comm_state_{
+      CommState::NORMAL}; // State of the communicator
+
+  onecclDataType_t getXcclDataType(const at::Tensor &tensor);
+  c10::intrusive_ptr<TorchWorkXCCL>
+  createWork(xpuStream_t stream, std::chrono::milliseconds timeout,
+             const std::vector<at::Tensor> &inputTensors);
+
+private:
+  // Helper that automatically cleans up premul sums.
+  struct RedOpRAII {
+    /* implicit */ RedOpRAII(onecclRedOp_t op);
+
+    // Constructor for PreMulSum reductions
+    explicit RedOpRAII(const ReduceOp &op, const onecclComm_t comm,
+                       const onecclDataType_t dataType,
+                       std::shared_ptr<XcclApi> xccl_api);
+
+    RedOpRAII() = delete;
+    RedOpRAII(const RedOpRAII &) = delete;
+    RedOpRAII &operator=(const RedOpRAII &) = delete;
+    RedOpRAII(RedOpRAII &&tmp) = delete;
+    RedOpRAII &operator=(RedOpRAII &&) = delete;
+    ~RedOpRAII();
+
+    operator onecclRedOp_t() const { return xcclRedOp_; }
+
+    onecclRedOp_t xcclRedOp_{onecclMaxRedOp};
+    onecclComm_t comm_{nullptr};
+    std::shared_ptr<XcclApi> xccl_api_;
+  };
+
+  // Constructor for split communicators
+  explicit TorchCommXCCL(const onecclComm_t xccl_comm);
+
+  // Private utility methods
+  size_t wordSize(onecclDataType_t type) const;
+  RedOpRAII getXcclReduceOp(const ReduceOp &op, const onecclComm_t comm,
+                            const onecclDataType_t dataType);
+  void timeoutWatchdog() noexcept;
+  void checkInitialized() const;
+  void checkAndAbortIfTimedOutOrError();
+  void checkWorkQueue(bool isMainThread);
+  void enqueueWork(c10::intrusive_ptr<TorchWorkXCCL> work, xpuStream_t stream);
+  xpuStream_t getOperationStream(bool async_op);
+  void ensureTensorContiguous(const at::Tensor &tensor);
+
+  // Member variables
+  onecclComm_t xccl_comm_{};
+  at::Device device_;
+  int comm_size_{};
+  int rank_{};
+  CommOptions options_;
+  size_t max_event_pool_size_{};
+  std::optional<xpuStream_t> internal_stream_; // Initialized in init()
+  std::optional<xpuEvent_t>
+      dependency_event_; // Pre-allocated event for stream dependencies
+  void *barrier_buffer_{}; // Pre-allocated XPU buffer for barrier operations
+  enum class InitializationState {
+    UNINITIALIZED,
+    INITIALIZED,
+    FINALIZED,
+  } init_state_;
+
+  // XCCL API abstraction
+  std::shared_ptr<XcclApi> xccl_api_;
+
+  // XPU API abstraction
+  std::shared_ptr<XpuApi> xpu_api_;
+
+  // Event pool management
+  std::queue<xpuEvent_t> event_pool_;
+  std::mutex event_pool_mutex_;
+
+  // Work tracking per stream
+  TorchWorkXCCLQueue workq_;
+
+  // Timeout monitoring
+  std::thread timeout_thread_;
+  std::atomic<bool> shutdown_;
+  std::condition_variable timeout_cv_;
+  std::mutex timeout_mutex_;
+
+  std::shared_ptr<TorchCommTracing> tracing_;
+  bool high_priority_stream_{false};
+  std::string name_;
+};
+
+} // namespace comms
+} // namespace torch
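Putting the surface together, a hedged end-to-end sketch: the hint keys are the ones init() parses, while the blocking `wait()` and the `ReduceOp::SUM` spelling are assumed from c10d conventions rather than shown in this patch:

    auto comm = std::make_shared<torch::comms::TorchCommXCCL>();
    torch::comms::CommOptions options;
    options.hints["torchcomm::xccl::high_priority_stream"] = "true";
    options.hints["torchcomm::xccl::max_event_pool_size"] = "500";
    comm->init(at::Device(at::kXPU, 0), "example_comm", options);

    at::Tensor t = at::ones({1024}, at::TensorOptions().device(at::kXPU));
    auto work = comm->all_reduce(t, torch::comms::ReduceOp::SUM, /*async_op=*/true);
    work->wait();  // assumed TorchWork API

    comm->finalize();  // must run before destruction, or the destructor logs an error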
diff --git a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp
new file mode 100644
index 00000000..14b91015
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.cpp
@@ -0,0 +1,256 @@
+#include "comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp"
+#include "comms/torchcomms/StoreManager.hpp"
+#include "comms/torchcomms/TorchCommLogging.hpp"
+#include "comms/torchcomms/TorchCommUtils.hpp"
+#include "comms/torchcomms/xccl/TorchCommXCCL.hpp"
+#include <algorithm>
+#include <cstring> // @manual
+
+namespace torch {
+namespace comms {
+
+// Initialize the static counter
+int TorchCommXCCLBootstrap::counter_ = 0;
+
+const std::string kUniqueidXchgMethodAuto = "auto";
+const std::string kUniqueidXchgMethodTCPStore = "tcpstore";
+const std::string kUniqueidXchgMethodDefault = kUniqueidXchgMethodAuto;
+
+TorchCommXCCLBootstrap::TorchCommXCCLBootstrap(
+    c10::intrusive_ptr<c10d::Store> store, c10::Device device,
+    std::shared_ptr<XcclApi> xccl_api, std::shared_ptr<XpuApi> xpu_api,
+    std::chrono::milliseconds timeout)
+    : timeout_(timeout), store_(store), created_internal_store_(false),
+      device_(device), xccl_api_(xccl_api), xpu_api_(xpu_api) {
+  // Query rank and size using the utility function
+  auto ranksize = query_ranksize();
+  rank_ = ranksize.first;
+  comm_size_ = ranksize.second;
+
+  const char *uniqueid_xchg_env =
+      std::getenv("TORCHCOMM_XCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD");
+  if (uniqueid_xchg_env == nullptr) {
+    TC_LOG(INFO)
+        << "TORCHCOMM_XCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD not set, "
+        << "defaulting to " << kUniqueidXchgMethodDefault;
+    uniqueid_xchg_method_ = kUniqueidXchgMethodDefault;
+  } else {
+    uniqueid_xchg_method_ = uniqueid_xchg_env;
+  }
+  std::transform(uniqueid_xchg_method_.begin(), uniqueid_xchg_method_.end(),
+                 uniqueid_xchg_method_.begin(),
+                 [](unsigned char c) { return std::tolower(c); });
+
+  if (device_.index() == -1) {
+    int device_count;
+    XPU_CHECK(xpu_api_, xpu_api_->getDeviceCount(&device_count),
+              "Failed to get XPU device count");
+
+    device_ = c10::Device(c10::kXPU, rank_ % device_count);
+    TC_LOG(INFO) << "User did not provide device ID; using device xpu:"
+                 << static_cast<int>(device_.index());
+  }
+
+  XPU_CHECK(xpu_api_, xpu_api_->setDevice(device_.index()),
+            "Failed to set device to " + std::to_string(device_.index()));
+
+  // Allocate XPU memory for a single float32 value used in barrier operations
+  XPU_CHECK(xpu_api_, xpu_api_->malloc(&barrier_buffer_, sizeof(float)),
+            "Failed to allocate barrier buffer");
+}
+
+TorchCommXCCLBootstrap::~TorchCommXCCLBootstrap() {
+  if (barrier_buffer_ != nullptr) {
+    XPU_CHECK(xpu_api_, xpu_api_->free(barrier_buffer_),
+              "Failed to free barrier buffer");
+    barrier_buffer_ = nullptr;
+  }
+}
+
+std::string TorchCommXCCLBootstrap::getXCCLStoreKey() {
+  std::string key = getXCCLStoreKeyPrefix() + std::to_string(counter_);
+  counter_++;
+  return key;
+}
+
+std::string TorchCommXCCLBootstrap::getXCCLStoreKeyPrefix() {
+  return "xccl_storekey_";
+}
+
+int TorchCommXCCLBootstrap::getXCCLStoreKeyCounter() { return counter_; }
+
+onecclUniqueId TorchCommXCCLBootstrap::exchangeUniqueIdStore() {
+  onecclUniqueId uniqueId;
+
+  auto key = getXCCLStoreKey();
+
+  if (rank_ == 0) {
+    // Generate unique ID on rank 0
+    onecclResult_t xcclErr = xccl_api_->getUniqueId(&uniqueId);
+
+    if (xcclErr != onecclSuccess) {
+      throw std::runtime_error("Failed to get XCCL unique ID: " +
+                               std::string(xccl_api_->getErrorString(xcclErr)));
+    }
+
+    // Set the unique ID in the store
+    std::vector<uint8_t> vec(reinterpret_cast<uint8_t*>(&uniqueId),
+                             reinterpret_cast<uint8_t*>(&uniqueId) +
+                                 sizeof(uniqueId));
+    store_->set(key, vec);
+  } else {
+    // Other ranks read the broadcast ID
+    auto vec = store_->get(key);
+
+    if (vec.size() != sizeof(onecclUniqueId)) {
+      throw std::runtime_error("Invalid XCCL unique ID size");
+    }
+    uniqueId = *(reinterpret_cast<onecclUniqueId*>(vec.data()));
+  }
+
+  return uniqueId;
+}
+
+onecclUniqueId
+TorchCommXCCLBootstrap::exchangeUniqueIdTCPStore(std::string_view name) {
+  store_ =
+      StoreManager::get().getStore(TorchCommXCCL::kBackendName, name, timeout_);
+  created_internal_store_ = true;
+
+  return exchangeUniqueIdStore();
+}
+
+bool TorchCommXCCLBootstrap::isTCPStoreEnabled() {
+  return std::getenv("MASTER_ADDR") && std::getenv("MASTER_PORT");
+}
+ throw std::runtime_error("Invalid unique ID exchange method " + + uniqueid_xchg_method_); + } + if (!is_tcp_store_enabled) { + throw std::runtime_error("No way to exchange unique ID"); + } + return exchangeUniqueIdTCPStore(name); +} + +void TorchCommXCCLBootstrap::cleanupTCPStore(onecclComm_t xccl_comm) { + if (created_internal_store_) { + // Delete the internal store object and do a barrier to ensure that all + // processes have deleted their store object too. This way, when we + // create the next torchcomm, we can use the same port to create a new store + // object. + store_.reset(); + + auto stream = xpu_api_->getCurrentXPUStream(device_.index()); + onecclResult_t result = + xccl_api_->allReduce(barrier_buffer_, barrier_buffer_, 1, onecclFloat32, + onecclSum, xccl_comm, stream); + if (result != onecclSuccess) { + TC_LOG(ERROR) << "XCCL AllReduce failed: " + << xccl_api_->getErrorString(result); + } + + XPU_CHECK(xpu_api_, xpu_api_->streamSynchronize(stream), + "Stream synchronization failed"); + } +} + +// Helper function to populate XCCL config from hints +void populateXcclConfigFromHints(onecclConfig_t &config, + const CommOptions &options, + const std::string &name) { + // Iterate over the hints and set the corresponding fields in the config. For + // string arguments, XCCL uses a "const char*" instead of a std::string, so + // it is hard to figure out the ownership structure. Here, we create a copy + // of the string and pass it to XCCL, so that it is responsible for freeing + // it. + + for (const auto &[key, val] : options.hints) { + if (key == "blocking") { + config.blocking = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.blocking=" << config.blocking; + } else if (key == "cgaClusterSize" || key == "cga_cluster_size") { + config.cgaClusterSize = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name << "] Setting config.cgaClusterSize=" + << config.cgaClusterSize; + } else if (key == "minCTAs" || key == "min_ctas") { + config.minCTAs = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.minCTAs=" << config.minCTAs; + } else if (key == "maxCTAs" || key == "max_ctas") { + config.maxCTAs = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.maxCTAs=" << config.maxCTAs; + } else if (key == "netName") { + config.netName = strdup(val.c_str()); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.netName=" << config.netName; + } else if (key == "splitShare" || key == "split_share") { + config.splitShare = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name + << "] Setting config.splitShare=" << config.splitShare; + } else if (key == "trafficClass" || key == "traffic_class" || + key == "commName" || key == "collnetEnable" || + key == "collnet_enable" || key == "CTAPolicy" || + key == "cta_policy" || key == "shrinkShare" || + key == "nvlsCTAs" || key == "nvls_ctas" || + key == "nChannelsPerNetPeer" || + key == "n_channels_per_net_peer" || + key == "nvlinkCentricSched" || key == "nvlink_centric_sched") { + TC_LOG(WARNING) << "XCCL hint '" << key + << "' is NCCL-specific and not supported by oneCCL, " + "ignoring for comm '" + << name << "'"; + } else { + TC_LOG(WARNING) + << "XCCL hint '" << key + << "' is not supported in this XCCL version, ignoring for comm '" + << name << "'"; + } + } +} + +onecclComm_t +TorchCommXCCLBootstrap::createXcclComm(const std::string &name, + const CommOptions &options) { + onecclUniqueId uniqueId; + onecclComm_t xccl_comm = nullptr; + + uniqueId = exchangeUniqueId(name); + + 
diff --git a/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp
new file mode 100644
index 00000000..ea677d4f
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCLBootstrap.hpp
@@ -0,0 +1,73 @@
+#pragma once
+
+#include <chrono>
+
+#include <c10/core/Device.h>
+#include <torch/torch.h> // @manual=//caffe2:torch-cpp
+
+#include "comms/torchcomms/TorchCommOptions.hpp"
+#include "comms/torchcomms/device/XpuApi.hpp"
+#include "comms/torchcomms/xccl/XcclApi.hpp"
+#include <memory>
+#include <string>
+
+namespace torch {
+namespace comms {
+
+constexpr uint16_t kTCPStorePort = 29500;
+
+class TorchCommXCCLBootstrap {
+public:
+  TorchCommXCCLBootstrap(c10::intrusive_ptr<c10d::Store> store,
+                         c10::Device device,
+                         std::shared_ptr<XcclApi> xccl_api,
+                         std::shared_ptr<XpuApi> xpu_api,
+                         std::chrono::milliseconds timeout);
+  ~TorchCommXCCLBootstrap();
+
+  // Delete copy and move operations
+  TorchCommXCCLBootstrap(const TorchCommXCCLBootstrap &) = delete;
+  TorchCommXCCLBootstrap &operator=(const TorchCommXCCLBootstrap &) = delete;
+  TorchCommXCCLBootstrap(TorchCommXCCLBootstrap &&) = delete;
+  TorchCommXCCLBootstrap &operator=(TorchCommXCCLBootstrap &&) = delete;
+
+  onecclComm_t createXcclComm(const std::string &name,
+                              const CommOptions &options = {});
+  static std::string getXCCLStoreKey();
+  static std::string getXCCLStoreKeyPrefix();
+  static int getXCCLStoreKeyCounter();
+
+  int getRank() { return rank_; }
+  int getSize() { return comm_size_; }
+  c10::Device getDevice() { return device_; }
+
+private:
+  onecclUniqueId exchangeUniqueId(std::string_view name);
+  onecclUniqueId exchangeUniqueIdStore();
+  onecclUniqueId exchangeUniqueIdTCPStore(std::string_view name);
+  bool isTCPStoreEnabled();
+  void cleanupTCPStore(onecclComm_t xccl_comm);
+
+private:
+  const std::chrono::milliseconds timeout_;
+  static int counter_;
+
+  c10::intrusive_ptr<c10d::Store> store_;
+  bool created_internal_store_;
+  c10::Device device_;
+  std::shared_ptr<XcclApi> xccl_api_;
+  std::shared_ptr<XpuApi> xpu_api_;
+  void *barrier_buffer_{nullptr};
+  int rank_;
+  int comm_size_;
+
+  std::string uniqueid_xchg_method_;
+};
+
+// Helper function to populate XCCL config from hints
+void populateXcclConfigFromHints(onecclConfig_t &config,
+                                 const CommOptions &options,
+                                 const std::string &name);
+
+} // namespace comms
+} // namespace torch
bindings for TorchComm"; + + py::class_>(m, "TorchCommXCCL"); +} diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp new file mode 100644 index 00000000..da754fb2 --- /dev/null +++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp @@ -0,0 +1,314 @@ +#include "comms/torchcomms/TorchCommLogging.hpp" +#include "comms/torchcomms/xccl/TorchCommXCCL.hpp" +#include +#include +#include +#include + +namespace torch { +namespace comms { + +namespace { + +onecclDataType_t getXcclDataTypeInternal(const at::Tensor &tensor) { + switch (tensor.scalar_type()) { + case at::ScalarType::Float: + return onecclFloat32; + case at::ScalarType::Double: + return onecclFloat64; + case at::ScalarType::Half: + return onecclFloat16; + case at::ScalarType::BFloat16: + return onecclBfloat16; + case at::ScalarType::Int: + return onecclInt32; + case at::ScalarType::Long: + return onecclInt64; + case at::ScalarType::Char: + return onecclInt8; + case at::ScalarType::Byte: + return onecclUint8; + default: + throw std::runtime_error("Unsupported tensor data type for XCCL"); + } +} + +template +void createPreMulSum(onecclRedOp_t *op, const PreMulSumFactorT &factor, + const onecclComm_t &comm, XcclApi *xccl_api) { + const bool is_tensor = std::holds_alternative(factor); + const auto residence = + is_tensor ? onecclScalarDevice : onecclScalarHostImmediate; + + at::Tensor tensor = is_tensor ? std::get(factor) : at::Tensor(); + T scalar_factor = is_tensor ? T{} : static_cast(std::get(factor)); + void *scalar = is_tensor ? tensor.data_ptr() : &scalar_factor; + + TORCH_INTERNAL_ASSERT(is_tensor ? dataType == getXcclDataTypeInternal(tensor) + : dataType != onecclBfloat16, + "PreMulSum factor type must match input data type"); + xccl_api->redOpCreatePreMulSum(op, scalar, dataType, residence, comm); +} + +} // namespace + +TorchCommXCCL::RedOpRAII::RedOpRAII(onecclRedOp_t op) + : xcclRedOp_(op), comm_(nullptr) {} + +TorchCommXCCL::RedOpRAII::RedOpRAII(const ReduceOp &op, const onecclComm_t comm, + const onecclDataType_t dataType, + std::shared_ptr xccl_api) + : comm_(comm), xccl_api_(std::move(xccl_api)) { + TORCH_INTERNAL_ASSERT( + op == ReduceOp::RedOpType::PREMUL_SUM, + "Constructing premul_sum RedOpRAII with non-premul_sum RedOpType"); + + if (!op.factor().has_value()) { + xcclRedOp_ = onecclSum; + comm_ = nullptr; + return; + } + + const auto &factor = op.factor().value(); + switch (dataType) { + case onecclFloat16: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + case onecclFloat32: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + case onecclBfloat16: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + case onecclFloat64: + createPreMulSum(&xcclRedOp_, factor, comm, + xccl_api_.get()); + break; + default: + throw std::runtime_error( + "PreMulSum Data type must be half, float, bfloat16 or double"); + } +} + +TorchCommXCCL::RedOpRAII::~RedOpRAII() { + if (comm_) { + xccl_api_->redOpDestroy(xcclRedOp_, comm_); + } +} + +size_t TorchCommXCCL::wordSize(onecclDataType_t type) const { + switch (type) { + case onecclInt8: + case onecclUint8: + return 1; + case onecclFloat16: + case onecclBfloat16: + return 2; + case onecclInt32: + case onecclUint32: + case onecclFloat32: + return 4; + case onecclInt64: + case onecclUint64: + case onecclFloat64: + return 8; + default: + return 0; + } +} + +onecclDataType_t TorchCommXCCL::getXcclDataType(const at::Tensor &tensor) { + return getXcclDataTypeInternal(tensor); +} + 
diff --git a/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp
new file mode 100644
index 00000000..da754fb2
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchCommXCCLUtils.cpp
@@ -0,0 +1,314 @@
+#include "comms/torchcomms/TorchCommLogging.hpp"
+#include "comms/torchcomms/xccl/TorchCommXCCL.hpp"
+#include <ATen/ATen.h>
+#include <c10/util/Exception.h>
+#include <memory>
+#include <variant>
+
+namespace torch {
+namespace comms {
+
+namespace {
+
+onecclDataType_t getXcclDataTypeInternal(const at::Tensor &tensor) {
+  switch (tensor.scalar_type()) {
+    case at::ScalarType::Float:
+      return onecclFloat32;
+    case at::ScalarType::Double:
+      return onecclFloat64;
+    case at::ScalarType::Half:
+      return onecclFloat16;
+    case at::ScalarType::BFloat16:
+      return onecclBfloat16;
+    case at::ScalarType::Int:
+      return onecclInt32;
+    case at::ScalarType::Long:
+      return onecclInt64;
+    case at::ScalarType::Char:
+      return onecclInt8;
+    case at::ScalarType::Byte:
+      return onecclUint8;
+    default:
+      throw std::runtime_error("Unsupported tensor data type for XCCL");
+  }
+}
+
+template <typename T, onecclDataType_t dataType>
+void createPreMulSum(onecclRedOp_t *op, const PreMulSumFactorT &factor,
+                     const onecclComm_t &comm, XcclApi *xccl_api) {
+  const bool is_tensor = std::holds_alternative<at::Tensor>(factor);
+  const auto residence =
+      is_tensor ? onecclScalarDevice : onecclScalarHostImmediate;
+
+  at::Tensor tensor = is_tensor ? std::get<at::Tensor>(factor) : at::Tensor();
+  T scalar_factor =
+      is_tensor ? T{} : static_cast<T>(std::get<double>(factor));
+  void *scalar = is_tensor ? tensor.data_ptr() : &scalar_factor;
+
+  TORCH_INTERNAL_ASSERT(is_tensor ? dataType == getXcclDataTypeInternal(tensor)
+                                  : dataType != onecclBfloat16,
+                        "PreMulSum factor type must match input data type");
+  xccl_api->redOpCreatePreMulSum(op, scalar, dataType, residence, comm);
+}
+
+} // namespace
+
+TorchCommXCCL::RedOpRAII::RedOpRAII(onecclRedOp_t op)
+    : xcclRedOp_(op), comm_(nullptr) {}
+
+TorchCommXCCL::RedOpRAII::RedOpRAII(const ReduceOp &op, const onecclComm_t comm,
+                                    const onecclDataType_t dataType,
+                                    std::shared_ptr<XcclApi> xccl_api)
+    : comm_(comm), xccl_api_(std::move(xccl_api)) {
+  TORCH_INTERNAL_ASSERT(
+      op == ReduceOp::RedOpType::PREMUL_SUM,
+      "Constructing premul_sum RedOpRAII with non-premul_sum RedOpType");
+
+  if (!op.factor().has_value()) {
+    xcclRedOp_ = onecclSum;
+    comm_ = nullptr;
+    return;
+  }
+
+  const auto &factor = op.factor().value();
+  switch (dataType) {
+    case onecclFloat16:
+      createPreMulSum<at::Half, onecclFloat16>(&xcclRedOp_, factor, comm,
+                                               xccl_api_.get());
+      break;
+    case onecclFloat32:
+      createPreMulSum<float, onecclFloat32>(&xcclRedOp_, factor, comm,
+                                            xccl_api_.get());
+      break;
+    case onecclBfloat16:
+      createPreMulSum<at::BFloat16, onecclBfloat16>(&xcclRedOp_, factor, comm,
+                                                    xccl_api_.get());
+      break;
+    case onecclFloat64:
+      createPreMulSum<double, onecclFloat64>(&xcclRedOp_, factor, comm,
+                                             xccl_api_.get());
+      break;
+    default:
+      throw std::runtime_error(
+          "PreMulSum Data type must be half, float, bfloat16 or double");
+  }
+}
+
+TorchCommXCCL::RedOpRAII::~RedOpRAII() {
+  if (comm_) {
+    xccl_api_->redOpDestroy(xcclRedOp_, comm_);
+  }
+}
+
+size_t TorchCommXCCL::wordSize(onecclDataType_t type) const {
+  switch (type) {
+    case onecclInt8:
+    case onecclUint8:
+      return 1;
+    case onecclFloat16:
+    case onecclBfloat16:
+      return 2;
+    case onecclInt32:
+    case onecclUint32:
+    case onecclFloat32:
+      return 4;
+    case onecclInt64:
+    case onecclUint64:
+    case onecclFloat64:
+      return 8;
+    default:
+      return 0;
+  }
+}
+
+onecclDataType_t TorchCommXCCL::getXcclDataType(const at::Tensor &tensor) {
+  return getXcclDataTypeInternal(tensor);
+}
+
+TorchCommXCCL::RedOpRAII
+TorchCommXCCL::getXcclReduceOp(const ReduceOp &op, const onecclComm_t comm,
+                               const onecclDataType_t dataType) {
+  switch (op) {
+    case ReduceOp::RedOpType::SUM:
+      return onecclSum;
+    case ReduceOp::RedOpType::PRODUCT:
+      return onecclProd;
+    case ReduceOp::RedOpType::MIN:
+      return onecclMin;
+    case ReduceOp::RedOpType::MAX:
+      return onecclMax;
+    case ReduceOp::RedOpType::BAND:
+      return onecclSum; // XCCL doesn't have bitwise AND, using SUM as fallback
+    case ReduceOp::RedOpType::BOR:
+      return onecclSum; // XCCL doesn't have bitwise OR, using SUM as fallback
+    case ReduceOp::RedOpType::BXOR:
+      return onecclSum; // XCCL doesn't have bitwise XOR, using SUM as fallback
+    case ReduceOp::RedOpType::PREMUL_SUM:
+      return RedOpRAII(op, comm, dataType, xccl_api_);
+    case ReduceOp::RedOpType::AVG:
+      return onecclAvg;
+    default:
+      throw std::runtime_error("Unsupported reduce operation");
+  }
+}
+
+void TorchCommXCCL::checkWorkQueue(bool isMainThread) {
+  TorchWorkXCCL::WorkStatus status = workq_.garbageCollect(isMainThread);
+
+  switch (status) {
+    case TorchWorkXCCL::WorkStatus::TIMEDOUT:
+      comm_state_ = CommState::TIMEOUT;
+      break;
+    case TorchWorkXCCL::WorkStatus::ERROR:
+      comm_state_ = CommState::ERROR;
+      break;
+    default:
+      // For COMPLETED, NOT_STARTED, and INPROGRESS, no state change needed
+      break;
+  }
+}
+
+// The timeout thread cannot make XCCL calls. The only XPU call it may make
+// is xpuEventQuery.
+void TorchCommXCCL::timeoutWatchdog() noexcept {
+  TC_LOG(INFO) << "Timeout thread starting for rank: " << rank_;
+  while (!shutdown_) {
+    {
+      std::unique_lock lock(timeout_mutex_);
+      // Wait for a shorter interval to check work objects periodically.
+      // Wake up either after 1 second or immediately if shutdown is requested
+      timeout_cv_.wait_for(lock, std::chrono::seconds(1),
+                           [this]() { return shutdown_.load(); });
+
+      // If we're shutting down, exit the loop
+      if (shutdown_) {
+        break;
+      }
+    }
+
+    // Check work objects for completion or timeout
+    checkWorkQueue(false);
+    if (comm_state_ != CommState::NORMAL &&
+        options_.abort_process_on_timeout_or_error) {
+      // Log the error and abort the process. We cannot abort the XCCL
+      // communicator as it is not safe to call XCCL operations from
+      // multiple threads at the same time.
+      if (comm_state_ == CommState::TIMEOUT) {
+        TC_LOG(ERROR) << "Aborting process due to timeout on rank " << rank_
+                      << " - timeout watchdog detected operation timeout";
+      } else if (comm_state_ == CommState::ERROR) {
+        TC_LOG(ERROR) << "Aborting process due to error on rank " << rank_
+                      << " - timeout watchdog detected operation error";
+      }
+      abort();
+    }
+  }
+
+  TC_LOG(INFO) << "Timeout thread exiting for rank: " << rank_;
+}
"; + } + abort(); + } + } + + TC_LOG(INFO) << "Timeout thread exiting for rank: " << rank_; +} + +void TorchCommXCCL::checkInitialized() const { + if (init_state_ != InitializationState::INITIALIZED) { + throw std::runtime_error("TorchCommXCCL not initialized"); + } +} + +void TorchCommXCCL::checkAndAbortIfTimedOutOrError() { + // First, check work queue status + checkWorkQueue(true); + + if (comm_state_ == CommState::TIMEOUT) { + // abortXcclComm(); // cannot abort oneCCL communicator + if (options_.abort_process_on_timeout_or_error) { + TC_LOG(ERROR) << "Aborting process due to timeout"; + abort(); + } else { + throw std::runtime_error("XCCL operation timed out"); + } + } else if (comm_state_ == CommState::ERROR) { + onecclResult_t asyncErr; + xccl_api_->commGetAsyncError(xccl_comm_, &asyncErr); + XCCLException xcclException(*xccl_api_, "XCCL Async Error", asyncErr); + // abortXcclComm(); // cannot abort oneCCL communicator + if (options_.abort_process_on_timeout_or_error) { + TC_LOG(ERROR) << "Aborting process due to error: " + << xcclException.what(); + abort(); + } else { + throw xcclException; + } + } +} + +c10::intrusive_ptr +TorchCommXCCL::createWork(xpuStream_t stream, std::chrono::milliseconds timeout, + const std::vector &inputTensors) { + // Only create the work object without enqueuing it + auto work = c10::make_intrusive(shared_from_this(), stream, + timeout, inputTensors, tracing_); + return work; +} + +void TorchCommXCCL::enqueueWork(c10::intrusive_ptr work, + xpuStream_t stream) { + // Add work to stream's queue after events have been recorded + workq_.enqueueWork(std::move(work), stream); +} + +xpuStream_t TorchCommXCCL::getOperationStream(bool async_op) { + if (async_op) { + // Get current PyTorch XPU stream for this device + xpuStream_t current_stream = xpu_api_->getCurrentXPUStream(device_.index()); + + // Record event on current stream and wait for it on internal stream + XPU_CHECK(xpu_api_, + xpu_api_->eventRecord(dependency_event_.value(), current_stream), + "Failed to record dependency event"); + + XPU_CHECK(xpu_api_, + xpu_api_->streamWaitEvent(internal_stream_.value(), + dependency_event_.value(), 0), + "Failed to make internal stream wait for dependency event"); + + return internal_stream_.value(); + } else { + // Use the current PyTorch XPU stream for synchronous operations + return xpu_api_->getCurrentXPUStream(device_.index()); + } +} + +void TorchCommXCCL::ensureTensorContiguous(const at::Tensor &tensor) { + if (!tensor.is_contiguous()) { + throw std::runtime_error("Tensor must be contiguous for XCCL operations"); + } +} + +// Protected methods (not in the private section of the header) +xpuEvent_t TorchCommXCCL::getEvent() { + std::lock_guard lock(event_pool_mutex_); + + if (!event_pool_.empty()) { + xpuEvent_t event = std::move(event_pool_.front()); + event_pool_.pop(); + return event; + } + + // Create new event if pool is empty + xpuEvent_t event; + XPU_CHECK(xpu_api_, xpu_api_->eventCreateWithFlags(event, /*flags=*/0), + "Failed to create event"); + return event; +} + +void TorchCommXCCL::returnEvent(xpuEvent_t &&event) { + std::lock_guard lock(event_pool_mutex_); + + if (event_pool_.size() < max_event_pool_size_) { + event_pool_.push(std::move(event)); + } else { + // Pool is full, destroy the event + XPU_CHECK(xpu_api_, xpu_api_->eventDestroy(event), + "Failed to destroy event"); + } +} +} // namespace comms +} // namespace torch diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.cpp b/comms/torchcomms/xccl/TorchWorkXCCL.cpp new file mode 100644 index 
diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.cpp b/comms/torchcomms/xccl/TorchWorkXCCL.cpp
new file mode 100644
index 00000000..1d40674b
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchWorkXCCL.cpp
@@ -0,0 +1,130 @@
+#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp"
+#include "comms/torchcomms/TorchCommLogging.hpp"
+#include "comms/torchcomms/xccl/TorchCommXCCL.hpp"
+#include <chrono>
+
+namespace torch {
+namespace comms {
+
+TorchWorkXCCL::TorchWorkXCCL(std::shared_ptr<TorchCommXCCL> comm,
+                             xpuStream_t stream,
+                             std::chrono::milliseconds timeout_ms,
+                             const std::vector<at::Tensor> &inputTensors,
+                             std::shared_ptr<TorchCommTracing> tracing)
+    : inputTensors_(inputTensors), comm_(std::move(comm)), stream_(stream),
+      timeout_ms_(timeout_ms), state_(WorkStatus::NOT_STARTED),
+      tracing_(std::move(tracing)) {
+  // If not in graph capture mode, create the events for start and end
+  // recording
+  start_event_ = comm_->getEvent();
+  end_event_ = comm_->getEvent();
+
+  // Events will be recorded around the actual XCCL operations
+}
+
+TorchWorkXCCL::~TorchWorkXCCL() {
+  if (!comm_) {
+    return;
+  }
+  // If not in graph capture mode, return the events to the pool
+  comm_->returnEvent(std::move(start_event_));
+  comm_->returnEvent(std::move(end_event_));
+}
+
+void TorchWorkXCCL::recordStart() {
+  XPU_CHECK(comm_->getXpuApi(),
+            comm_->getXpuApi()->eventRecord(start_event_, stream_),
+            "Failed to record start event");
+}
+
+void TorchWorkXCCL::recordEnd() {
+  XPU_CHECK(comm_->getXpuApi(),
+            comm_->getXpuApi()->eventRecord(end_event_, stream_),
+            "Failed to record end event");
+}
+
+bool TorchWorkXCCL::isCompleted() { return state_ == WorkStatus::COMPLETED; }
+
+TorchWorkXCCL::WorkStatus TorchWorkXCCL::checkStatus() {
+  // If already marked as completed, errored, or timed out, return that state
+  if (state_ == WorkStatus::COMPLETED || state_ == WorkStatus::ERROR ||
+      state_ == WorkStatus::TIMEDOUT) {
+    return state_;
+  }
+
+  // Step 1: If start_completed_time_ doesn't have a value yet, query the start
+  // event
+  if (!start_completed_time_.has_value()) {
+    xpu_result_t start_status = comm_->getXpuApi()->eventQuery(start_event_);
+
+    if (start_status == XPU_SUCCESS) {
+      // Start event has completed, store the current time
+      start_completed_time_ = std::chrono::steady_clock::now();
+      state_ = WorkStatus::INPROGRESS;
+    } else if (start_status != XPU_ERROR_NOT_READY &&
+               start_status != XPU_ERROR_UNSUPPORTED) {
+      // Some other error occurred with the start event
+      TC_LOG(ERROR) << "XPU error during start event query: "
+                    << comm_->getXpuApi()->getErrorString(start_status) << " ("
+                    << start_status << ")";
+      state_ = WorkStatus::ERROR;
+    }
+  }
+  if (state_ == WorkStatus::NOT_STARTED || state_ == WorkStatus::ERROR) {
+    return state_;
+  }
+
+  // Step 2: If we get here, the start event has completed, so query the end
+  // event
+  xpu_result_t end_status = comm_->getXpuApi()->eventQuery(end_event_);
+
+  if (end_status == XPU_SUCCESS) {
+    // End event has completed, mark the work as completed
+    state_ = WorkStatus::COMPLETED;
+
+    // Release the input tensors to keep the lifetime of the tensors short
+    inputTensors_.clear();
+  } else if (end_status == XPU_ERROR_NOT_READY) {
+    // End event has not completed yet, check for timeout
+    auto current_time = std::chrono::steady_clock::now();
+    auto elapsed_milliseconds =
+        std::chrono::duration_cast<std::chrono::milliseconds>(
+            current_time - start_completed_time_.value());
+
+    // Check if the operation has timed out
+    if (elapsed_milliseconds > timeout_ms_) {
+      // Operation has timed out
+      state_ = WorkStatus::TIMEDOUT;
+    }
+  } else if (end_status != XPU_ERROR_UNSUPPORTED) {
+    // Some other error occurred with the end event
+    TC_LOG(ERROR) << "XPU error during end event query: "
+                  << comm_->getXpuApi()->getErrorString(end_status) << " ("
+                  << end_status << ")";
+    state_ = WorkStatus::ERROR;
+  }
+  return state_;
+}
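checkStatus() is worth reading as a small state machine driven by two device-side events; in particular, the timeout clock starts only once the start event is observed complete, so time spent queued behind earlier stream work is not billed against the operation's timeout. The core transitions, distilled into a runnable host-only sketch:

#include <chrono>
#include <iostream>
#include <optional>

enum class Status { NOT_STARTED, INPROGRESS, COMPLETED, TIMEDOUT };

// start_done / end_done stand in for eventQuery() on the two events.
Status check(bool start_done, bool end_done,
             std::optional<std::chrono::steady_clock::time_point> &started,
             std::chrono::milliseconds timeout) {
  if (!started) {
    if (!start_done) {
      return Status::NOT_STARTED;
    }
    started = std::chrono::steady_clock::now(); // clock starts here, not at enqueue
  }
  if (end_done) {
    return Status::COMPLETED;
  }
  auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
      std::chrono::steady_clock::now() - *started);
  return elapsed > timeout ? Status::TIMEDOUT : Status::INPROGRESS;
}

int main() {
  std::optional<std::chrono::steady_clock::time_point> started;
  auto timeout = std::chrono::milliseconds(10);
  std::cout << static_cast<int>(check(false, false, started, timeout)) << "\n"; // 0
  std::cout << static_cast<int>(check(true, false, started, timeout)) << "\n";  // 1
  std::cout << static_cast<int>(check(true, true, started, timeout)) << "\n";   // 2
}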
+
+void TorchWorkXCCL::wait() {
+  // If already completed, return immediately
+  WorkStatus local_state = state_;
+  if (local_state == WorkStatus::COMPLETED ||
+      local_state == WorkStatus::ERROR ||
+      local_state == WorkStatus::TIMEDOUT) {
+    return;
+  }
+
+  tracing_->recordEvent("wait");
+
+  // Get the current stream using the device from the comm object
+  xpuStream_t current_stream =
+      comm_->getXpuApi()->getCurrentXPUStream(comm_->device_.index());
+
+  // Add a dependency from the work's stream to the current stream.
+  // This makes the current stream wait for the end_event_ recorded on the
+  // work's stream.
+  XPU_CHECK(comm_->getXpuApi(),
+            comm_->getXpuApi()->streamWaitEvent(current_stream, end_event_, 0),
+            "Failed to make stream wait for event");
+}
+} // namespace comms
+} // namespace torch
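Note that wait() never blocks the host: it records a stream-level dependency on end_event_ and returns, so only subsequent work submitted to the caller's stream is ordered after the collective. For contrast, a host-side analogue built on std::future, where the consumer really does block the calling thread (a sketch of the completion-event idea, not of what the stream version does):

#include <chrono>
#include <future>
#include <iostream>
#include <thread>

int main() {
  std::promise<void> end_event; // stands in for end_event_
  std::future<void> done = end_event.get_future();

  std::thread producer([&] {
    // Stands in for the collective running on the work's stream.
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    end_event.set_value(); // like recording end_event_ at completion
  });

  done.wait(); // blocks this thread; the stream version blocks no host thread
  std::cout << "work complete\n";
  producer.join();
}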
diff --git a/comms/torchcomms/xccl/TorchWorkXCCL.hpp b/comms/torchcomms/xccl/TorchWorkXCCL.hpp
new file mode 100644
index 00000000..d81e557e
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchWorkXCCL.hpp
@@ -0,0 +1,95 @@
+#pragma once
+
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <mutex>
+#include <optional>
+#include <queue>
+#include <unordered_map>
+
+#include "comms/torchcomms/TorchCommTracing.hpp"
+#include "comms/torchcomms/TorchWork.hpp"
+#include "comms/torchcomms/device/XpuApi.hpp"
+#include <ATen/ATen.h>
+
+namespace torch {
+namespace comms {
+
+// Forward declaration
+class TorchCommXCCL;
+
+class TorchWorkXCCL : public TorchWork {
+public:
+  // Status of a work object
+  enum class WorkStatus {
+    NOT_STARTED, // Work has not started yet
+    INPROGRESS,  // Work is still in progress
+    COMPLETED,   // Work has completed successfully
+    TIMEDOUT,    // Work has timed out
+    ERROR        // Work has encountered an error
+  };
+
+  TorchWorkXCCL(std::shared_ptr<TorchCommXCCL> comm, xpuStream_t stream,
+                std::chrono::milliseconds timeout_ms,
+                const std::vector<at::Tensor> &inputTensors,
+                std::shared_ptr<TorchCommTracing> tracing);
+  ~TorchWorkXCCL() override;
+
+  // Delete copy and move operations
+  TorchWorkXCCL(const TorchWorkXCCL &) = delete;
+  TorchWorkXCCL(TorchWorkXCCL &&) = delete;
+  TorchWorkXCCL &operator=(const TorchWorkXCCL &) = delete;
+  TorchWorkXCCL &operator=(TorchWorkXCCL &&) = delete;
+
+  // Override virtual functions from TorchWork
+  bool isCompleted() override;
+  void wait() override;
+
+protected:
+  void recordStart();
+  void recordEnd();
+
+  friend class TorchCommXCCL;
+  friend class TorchWorkXCCLQueue;
+
+private:
+  // Check the status of the work object
+  WorkStatus checkStatus();
+
+  std::chrono::milliseconds getTimeout() const { return timeout_ms_; }
+  std::vector<at::Tensor> inputTensors_;
+
+  std::shared_ptr<TorchCommXCCL> comm_;
+  xpuEvent_t start_event_;
+  xpuEvent_t end_event_;
+  xpuStream_t stream_; // stream is not owned by this class
+
+  std::chrono::milliseconds timeout_ms_;
+
+  // State machine variables. TODO: convert to a proper state machine later
+  std::atomic<WorkStatus> state_;
+
+  std::optional<std::chrono::steady_clock::time_point> start_completed_time_;
+  std::shared_ptr<TorchCommTracing> tracing_;
+};
+
+class TorchWorkXCCLQueue {
+public:
+  TorchWorkXCCLQueue() = default;
+  ~TorchWorkXCCLQueue() = default;
+
+  TorchWorkXCCL::WorkStatus garbageCollect(bool isMainThread);
+  // The finalize function can only be called from the main thread
+  TorchWorkXCCL::WorkStatus finalize();
+  void enqueueWork(c10::intrusive_ptr<TorchWorkXCCL> work, xpuStream_t stream);
+
+private:
+  std::unordered_map<xpuStream_t,
+                     std::queue<c10::intrusive_ptr<TorchWorkXCCL>>>
+      stream_work_queues_;
+  std::vector<c10::intrusive_ptr<TorchWorkXCCL>> completed_work_queue_;
+  std::recursive_mutex work_queues_mutex_;
+};
+
+} // namespace comms
+} // namespace torch
diff --git a/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp
new file mode 100644
index 00000000..e4edecbc
--- /dev/null
+++ b/comms/torchcomms/xccl/TorchWorkXCCLQueue.cpp
@@ -0,0 +1,96 @@
+#include "comms/torchcomms/xccl/TorchWorkXCCL.hpp"
+
+namespace torch {
+namespace comms {
+
+TorchWorkXCCL::WorkStatus
+TorchWorkXCCLQueue::garbageCollect(bool isMainThread) {
+  std::lock_guard<std::recursive_mutex> lock(work_queues_mutex_);
+
+  TorchWorkXCCL::WorkStatus last_status = TorchWorkXCCL::WorkStatus::COMPLETED;
+
+  // Keep popping completed elements until we hit an in-progress element
+  // or the queue is empty.
+  // Use an iterator to safely remove empty queues while iterating.
+  auto it = stream_work_queues_.begin();
+  while (it != stream_work_queues_.end()) {
+    auto &work_queue = it->second;
+
+    while (!work_queue.empty()) {
+      // Get the first work object in the queue
+      auto work = work_queue.front();
+
+      // Use the checkStatus function to determine the work status
+      TorchWorkXCCL::WorkStatus status = work->checkStatus();
+      last_status = status;
+
+      if (status == TorchWorkXCCL::WorkStatus::COMPLETED) {
+        // Work is completed, remove it from the work queue
+        work_queue.pop();
+        completed_work_queue_.push_back(work);
+        // Continue to the next element in the queue
+      } else if (status == TorchWorkXCCL::WorkStatus::TIMEDOUT ||
+                 status == TorchWorkXCCL::WorkStatus::ERROR) {
+        // Return the error status immediately
+        return status;
+      } else {
+        // NOT_STARTED or INPROGRESS - stop processing this queue
+        break;
+      }
+    }
+
+    // If the queue is now empty, remove it from the map
+    if (work_queue.empty()) {
+      it = stream_work_queues_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+
+  if (isMainThread) {
+    // If we are the main thread, clear the completed work queue
+    completed_work_queue_.clear();
+  }
+
+  return last_status;
+}
+
+TorchWorkXCCL::WorkStatus TorchWorkXCCLQueue::finalize() {
+  // Because this function is typically called after the timeout thread has
+  // already joined, we might not need to lock here. But we take the lock
+  // anyway, as defensive programming, in case someone changes the thread
+  // join order later. The lock itself is cheap on modern Linux systems
+  // (uncontended locks are typically just an atomic operation).
+  std::lock_guard<std::recursive_mutex> lock(work_queues_mutex_);
+
+  // Initialize the status to COMPLETED to cover the case where the queue is
+  // empty
+  TorchWorkXCCL::WorkStatus status = TorchWorkXCCL::WorkStatus::COMPLETED;
+  while (!stream_work_queues_.empty()) {
+    status = garbageCollect(true);
+    if (status == TorchWorkXCCL::WorkStatus::ERROR ||
+        status == TorchWorkXCCL::WorkStatus::TIMEDOUT ||
+        status == TorchWorkXCCL::WorkStatus::COMPLETED) {
+      break;
+    }
+  }
+
+  // Clear all work queues & the completed work queue.
+  //
+  // NOTE: finalize MUST return without holding references to any work object,
+  // otherwise it may leak objects and cause side effects.
+  stream_work_queues_.clear();
+  completed_work_queue_.clear();
+
+  return status;
+}
+
+void TorchWorkXCCLQueue::enqueueWork(c10::intrusive_ptr<TorchWorkXCCL> work,
+                                     xpuStream_t stream) {
+  // Add work to the stream's queue after events have been recorded
+  std::lock_guard<std::recursive_mutex> lock(work_queues_mutex_);
+  stream_work_queues_[stream].push(work);
+}
+
+} // namespace comms
+} // namespace torch
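garbageCollect() above preserves per-stream FIFO order: within each stream's queue it retires completed work from the front and stops at the first item still in flight, and it drops a stream's queue entirely once drained. Distilled into a self-contained sketch with integers standing in for work objects:

#include <iostream>
#include <iterator>
#include <queue>
#include <unordered_map>

enum class Status { INPROGRESS, COMPLETED };

// Negative values mean "done"; the real code calls work->checkStatus().
Status check(int work) {
  return work < 0 ? Status::COMPLETED : Status::INPROGRESS;
}

void garbage_collect(std::unordered_map<int, std::queue<int>> &queues) {
  for (auto it = queues.begin(); it != queues.end();) {
    auto &q = it->second;
    // Retire completed work from the front; stop at the first in-flight item.
    while (!q.empty() && check(q.front()) == Status::COMPLETED) {
      q.pop();
    }
    // Drop a stream's queue once it has drained.
    it = q.empty() ? queues.erase(it) : std::next(it);
  }
}

int main() {
  std::unordered_map<int, std::queue<int>> queues;
  queues[0].push(-1);
  queues[0].push(-2); // stream 0: everything complete
  queues[1].push(-3);
  queues[1].push(7); // stream 1: blocked behind in-flight item 7
  garbage_collect(queues);
  std::cout << "streams left: " << queues.size()
            << ", pending on stream 1: " << queues[1].size() << "\n"; // 1, 1
}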
diff --git a/comms/torchcomms/xccl/XcclApi.cpp b/comms/torchcomms/xccl/XcclApi.cpp
new file mode 100644
index 00000000..ec0c1f7b
--- /dev/null
+++ b/comms/torchcomms/xccl/XcclApi.cpp
@@ -0,0 +1,155 @@
+#include "comms/torchcomms/xccl/XcclApi.hpp"
+#include "comms/torchcomms/TorchCommLogging.hpp"
+
+namespace torch {
+namespace comms {
+
+const char *DefaultXcclApi::getErrorString(onecclResult_t result) {
+  return onecclGetErrorString(result);
+}
+
+onecclResult_t DefaultXcclApi::setDevice(int device) {
+  return onecclSetDevice(device);
+}
+
+onecclResult_t DefaultXcclApi::getUniqueId(onecclUniqueId *uniqueId) {
+  return onecclGetUniqueId(uniqueId);
+}
+
+onecclResult_t DefaultXcclApi::commInitRankConfig(onecclComm_t *comm,
+                                                  int nranks,
+                                                  onecclUniqueId commId,
+                                                  int rank,
+                                                  onecclConfig_t *config) {
+  return onecclCommInitRankConfig(comm, nranks, commId, rank, config);
+}
+
+onecclResult_t DefaultXcclApi::commDestroy(onecclComm_t comm) {
+  return onecclCommDestroy(comm);
+}
+
+onecclResult_t DefaultXcclApi::commAbort(onecclComm_t comm) {
+  // return onecclCommAbort(comm);
+  return onecclNotImplemented;
+}
+
+onecclResult_t DefaultXcclApi::commGetAsyncError(onecclComm_t comm,
+                                                 onecclResult_t *asyncError) {
+  // return onecclCommGetAsyncError(comm, asyncError);
+  return onecclNotImplemented;
+}
+
+onecclResult_t DefaultXcclApi::commSplit(onecclComm_t comm, int color, int key,
+                                         onecclComm_t *newcomm,
+                                         onecclConfig_t *config) {
+  return onecclCommSplit(comm, color, key, newcomm, config);
+}
+
+onecclResult_t DefaultXcclApi::commRegister(onecclComm_t comm, void *buffer,
+                                            size_t size, void **handle) {
+  // return onecclCommRegister(comm, buffer, size, handle);
+  return onecclNotImplemented;
+}
+
+onecclResult_t DefaultXcclApi::commDeregister(onecclComm_t comm,
+                                              void *handle) {
+  // return onecclCommDeregister(comm, handle);
+  return onecclNotImplemented;
+}
+
+onecclResult_t DefaultXcclApi::send(const void *sendbuff, size_t count,
+                                    onecclDataType_t datatype, int peer,
+                                    onecclComm_t comm, xpuStream_t stream) {
+  return onecclSend(const_cast<void *>(sendbuff), count, datatype, peer, comm,
+                    stream);
+}
+
+onecclResult_t DefaultXcclApi::recv(void *recvbuff, size_t count,
+                                    onecclDataType_t datatype, int peer,
+                                    onecclComm_t comm, xpuStream_t stream) {
+  return onecclRecv(recvbuff, count, datatype, peer, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::broadcast(const void *sendbuff, void *recvbuff,
+                                         size_t count,
+                                         onecclDataType_t datatype, int root,
+                                         onecclComm_t comm,
+                                         xpuStream_t stream) {
+  return onecclBroadcast(const_cast<void *>(sendbuff), recvbuff, count,
+                         datatype, root, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::bcast(void *buff, size_t count,
+                                     onecclDataType_t datatype, int root,
+                                     onecclComm_t comm, xpuStream_t stream) {
+  return onecclBroadcast(buff, buff, count, datatype, root, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::allReduce(const void *sendbuff, void *recvbuff,
+                                         size_t count,
+                                         onecclDataType_t datatype,
+                                         onecclRedOp_t op, onecclComm_t comm,
+                                         xpuStream_t stream) {
+  return onecclAllReduce(const_cast<void *>(sendbuff), recvbuff, count,
+                         datatype, op, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::reduce(const void *sendbuff, void *recvbuff,
+                                      size_t count, onecclDataType_t datatype,
+                                      onecclRedOp_t op, int root,
+                                      onecclComm_t comm, xpuStream_t stream) {
+  return onecclReduce(const_cast<void *>(sendbuff), recvbuff, count, datatype,
+                      op, root, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::allGather(const void *sendbuff, void *recvbuff,
+                                         size_t sendcount,
+                                         onecclDataType_t datatype,
+                                         onecclComm_t comm,
+                                         xpuStream_t stream) {
+  return onecclAllGather(const_cast<void *>(sendbuff), recvbuff, sendcount,
+                         datatype, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::reduceScatter(const void *sendbuff,
+                                             void *recvbuff, size_t recvcount,
+                                             onecclDataType_t datatype,
+                                             onecclRedOp_t op,
+                                             onecclComm_t comm,
+                                             xpuStream_t stream) {
+  return onecclReduceScatter(const_cast<void *>(sendbuff), recvbuff, recvcount,
+                             datatype, op, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::allToAll(const void *sendbuff, void *recvbuff,
+                                        size_t count,
+                                        onecclDataType_t datatype,
+                                        onecclComm_t comm,
+                                        xpuStream_t stream) {
+  return onecclAllToAll(const_cast<void *>(sendbuff), recvbuff, count,
+                        datatype, comm, stream);
+}
+
+onecclResult_t DefaultXcclApi::groupStart() { return onecclGroupStart(); }
+
+onecclResult_t DefaultXcclApi::groupEnd() { return onecclGroupEnd(); }
+
+onecclResult_t DefaultXcclApi::commUserRank(const onecclComm_t comm,
+                                            int *myRank) {
+  return onecclCommUserRank(comm, myRank);
+}
+
+onecclResult_t DefaultXcclApi::commCount(const onecclComm_t comm, int *count) {
+  return onecclCommCount(comm, count);
+}
+
+onecclResult_t DefaultXcclApi::redOpCreatePreMulSum(
+    onecclRedOp_t *op, void *scalar, onecclDataType_t datatype,
+    onecclScalarResidence_t residence, onecclComm_t comm) {
+  return onecclRedOpCreatePreMulSum(op, scalar, datatype, residence, comm);
+}
+
+onecclResult_t DefaultXcclApi::redOpDestroy(onecclRedOp_t op,
+                                            onecclComm_t comm) {
+  return onecclRedOpDestroy(op, comm);
+}
+
+} // namespace comms
+} // namespace torch
diff --git a/comms/torchcomms/xccl/XcclApi.hpp b/comms/torchcomms/xccl/XcclApi.hpp
new file mode 100644
index 00000000..9e17bb26
--- /dev/null
+++ b/comms/torchcomms/xccl/XcclApi.hpp
@@ -0,0 +1,189 @@
+#pragma once
+
+#include <cstddef>
+#include <oneccl.h> // assumed header name for the oneccl* API
+
+#include "comms/torchcomms/device/XpuApi.hpp"
+
+namespace torch {
+namespace comms {
+
+class XcclApi {
+public:
+  virtual ~XcclApi() = default;
+
+  virtual const char *getErrorString(onecclResult_t result) = 0;
+
+  virtual onecclResult_t setDevice(int device) = 0;
+
+  virtual onecclResult_t getUniqueId(onecclUniqueId *uniqueId) = 0;
+
+  virtual onecclResult_t commInitRankConfig(onecclComm_t *comm, int nranks,
+                                            onecclUniqueId commId, int rank,
+                                            onecclConfig_t *config) = 0;
+
+  virtual onecclResult_t commDestroy(onecclComm_t comm) = 0;
+
+  virtual onecclResult_t commAbort(onecclComm_t comm) = 0;
+
+  virtual onecclResult_t commGetAsyncError(onecclComm_t comm,
+                                           onecclResult_t *asyncError) = 0;
+
+  virtual onecclResult_t commSplit(onecclComm_t comm, int color, int key,
+                                   onecclComm_t *newcomm,
+                                   onecclConfig_t *config) = 0;
+
+  virtual onecclResult_t commRegister(onecclComm_t comm, void *buffer,
+                                      size_t size, void **handle) = 0;
+
+  virtual onecclResult_t commDeregister(onecclComm_t comm, void *handle) = 0;
+
+  // Point-to-point operations
+  virtual onecclResult_t send(const void *sendbuff, size_t count,
+                              onecclDataType_t datatype, int peer,
+                              onecclComm_t comm, xpuStream_t stream) = 0;
+
+  virtual onecclResult_t recv(void *recvbuff, size_t count,
+                              onecclDataType_t datatype, int peer,
+                              onecclComm_t comm, xpuStream_t stream) = 0;
+
+  // Collective operations
+  virtual onecclResult_t broadcast(const void *sendbuff, void *recvbuff,
+                                   size_t count, onecclDataType_t datatype,
+                                   int root, onecclComm_t comm,
+                                   xpuStream_t stream) = 0;
+
+  virtual onecclResult_t bcast(void *buff, size_t count,
+                               onecclDataType_t datatype, int root,
+                               onecclComm_t comm, xpuStream_t stream) = 0;
+
+  virtual onecclResult_t allReduce(const void *sendbuff, void *recvbuff,
+                                   size_t count, onecclDataType_t datatype,
+                                   onecclRedOp_t op, onecclComm_t comm,
+                                   xpuStream_t stream) = 0;
+
+  virtual onecclResult_t reduce(const void *sendbuff, void *recvbuff,
+                                size_t count, onecclDataType_t datatype,
+                                onecclRedOp_t op, int root, onecclComm_t comm,
+                                xpuStream_t stream) = 0;
+
+  virtual onecclResult_t allGather(const void *sendbuff, void *recvbuff,
+                                   size_t sendcount, onecclDataType_t datatype,
+                                   onecclComm_t comm, xpuStream_t stream) = 0;
+
+  virtual onecclResult_t reduceScatter(const void *sendbuff, void *recvbuff,
+                                       size_t recvcount,
+                                       onecclDataType_t datatype,
+                                       onecclRedOp_t op, onecclComm_t comm,
+                                       xpuStream_t stream) = 0;
+
+  virtual onecclResult_t allToAll(const void *sendbuff, void *recvbuff,
+                                  size_t count, onecclDataType_t datatype,
+                                  onecclComm_t comm, xpuStream_t stream) = 0;
+
+  // Group operations
+  virtual onecclResult_t groupStart() = 0;
+  virtual onecclResult_t groupEnd() = 0;
+
+  virtual onecclResult_t commUserRank(const onecclComm_t comm,
+                                      int *userRank) = 0;
+  virtual onecclResult_t commCount(const onecclComm_t comm, int *count) = 0;
+
+  virtual onecclResult_t redOpCreatePreMulSum(onecclRedOp_t *op, void *scalar,
+                                              onecclDataType_t datatype,
+                                              onecclScalarResidence_t residence,
+                                              onecclComm_t comm) = 0;
+  virtual onecclResult_t redOpDestroy(onecclRedOp_t op, onecclComm_t comm) = 0;
+};
+
+/**
+ * Default implementation that calls the underlying XCCL APIs directly.
+ */
+class DefaultXcclApi : public XcclApi {
+public:
+  ~DefaultXcclApi() override = default;
+
+  // Error handling
+  const char *getErrorString(onecclResult_t result) override;
+
+  // Device management
+  onecclResult_t setDevice(int device) override;
+
+  // Unique ID generation
+  onecclResult_t getUniqueId(onecclUniqueId *uniqueId) override;
+
+  // Communicator management
+  onecclResult_t commInitRankConfig(onecclComm_t *comm, int nranks,
+                                    onecclUniqueId commId, int rank,
+                                    onecclConfig_t *config) override;
+
+  onecclResult_t commDestroy(onecclComm_t comm) override;
+
+  onecclResult_t commAbort(onecclComm_t comm) override;
+
+  onecclResult_t commGetAsyncError(onecclComm_t comm,
+                                   onecclResult_t *asyncError) override;
+
+  onecclResult_t commSplit(onecclComm_t comm, int color, int key,
+                           onecclComm_t *newcomm,
+                           onecclConfig_t *config) override;
+
+  onecclResult_t commRegister(onecclComm_t comm, void *buffer, size_t size,
+                              void **handle) override;
+
+  onecclResult_t commDeregister(onecclComm_t comm, void *handle) override;
+
+  // Point-to-point operations
+  onecclResult_t send(const void *sendbuff, size_t count,
+                      onecclDataType_t datatype, int peer, onecclComm_t comm,
+                      xpuStream_t stream) override;
+
+  onecclResult_t recv(void *recvbuff, size_t count, onecclDataType_t datatype,
+                      int peer, onecclComm_t comm,
+                      xpuStream_t stream) override;
+
+  // Collective operations
+  onecclResult_t broadcast(const void *sendbuff, void *recvbuff, size_t count,
+                           onecclDataType_t datatype, int root,
+                           onecclComm_t comm, xpuStream_t stream) override;
+
+  onecclResult_t bcast(void *buff, size_t count, onecclDataType_t datatype,
+                       int root, onecclComm_t comm,
+                       xpuStream_t stream) override;
+
+  onecclResult_t allReduce(const void *sendbuff, void *recvbuff, size_t count,
+                           onecclDataType_t datatype, onecclRedOp_t op,
+                           onecclComm_t comm, xpuStream_t stream) override;
+
+  onecclResult_t reduce(const void *sendbuff, void *recvbuff, size_t count,
+                        onecclDataType_t datatype, onecclRedOp_t op, int root,
+                        onecclComm_t comm, xpuStream_t stream) override;
+
+  onecclResult_t allGather(const void *sendbuff, void *recvbuff,
+                           size_t sendcount, onecclDataType_t datatype,
+                           onecclComm_t comm, xpuStream_t stream) override;
+
+  onecclResult_t reduceScatter(const void *sendbuff, void *recvbuff,
+                               size_t recvcount, onecclDataType_t datatype,
+                               onecclRedOp_t op, onecclComm_t comm,
+                               xpuStream_t stream) override;
+
+  onecclResult_t allToAll(const void *sendbuff, void *recvbuff, size_t count,
+                          onecclDataType_t datatype, onecclComm_t comm,
+                          xpuStream_t stream) override;
+
+  // Group operations
+  onecclResult_t groupStart() override;
+  onecclResult_t groupEnd() override;
+
+  onecclResult_t commUserRank(const onecclComm_t comm, int *userRank) override;
+  onecclResult_t commCount(const onecclComm_t comm, int *count) override;
+
+  onecclResult_t redOpCreatePreMulSum(onecclRedOp_t *op, void *scalar,
+                                      onecclDataType_t datatype,
+                                      onecclScalarResidence_t residence,
+                                      onecclComm_t comm) override;
+  onecclResult_t redOpDestroy(onecclRedOp_t op, onecclComm_t comm) override;
+};
+
+} // namespace comms
+} // namespace torch
diff --git a/setup.py b/setup.py
index a4e6842e..5d064511 100644
--- a/setup.py
+++ b/setup.py
@@ -40,6 +40,7 @@ def flag_str(val: bool):
 USE_GLOO = flag_enabled("USE_GLOO", True)
 USE_RCCL = flag_enabled("USE_RCCL", False)
 USE_RCCLX = flag_enabled("USE_RCCLX", False)
+USE_XCCL = flag_enabled("USE_XCCL", False)
 
 requirement_path = os.path.join(ROOT, "requirements.txt")
 try:
@@ -107,6 +108,7 @@ def build_cmake(self, ext):
             f"-DUSE_GLOO={flag_str(USE_GLOO)}",
             f"-DUSE_RCCL={flag_str(USE_RCCL)}",
             f"-DUSE_RCCLX={flag_str(USE_RCCLX)}",
+            f"-DUSE_XCCL={flag_str(USE_XCCL)}",
         ]
         build_args = ["--", "-j"]
@@ -150,6 +152,10 @@ def build_cmake(self, ext):
 if USE_RCCLX:
     ext_modules += [
         CMakeExtension("torchcomms._comms_rcclx"),
     ]
+if USE_XCCL:
+    ext_modules += [
+        CMakeExtension("torchcomms._comms_xccl"),
+    ]
 
 setup(
@@ -164,6 +170,7 @@ def build_cmake(self, ext):
             "gloo = torchcomms._comms_gloo",
             "rccl = torchcomms._comms_rccl",
             "rcclx = torchcomms._comms_rcclx",
+            "xccl = torchcomms._comms_xccl",
         ]
     },
     ext_modules=ext_modules,
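The diff routes every oneCCL call through the XcclApi virtual interface rather than calling oneccl* functions directly, which lets tests substitute a fake for the oneCCL-backed DefaultXcclApi. A self-contained illustration of that seam, with a deliberately trimmed two-method interface (not the real XcclApi surface; onecclResult_t is modeled as int):

#include <cstddef>
#include <iostream>
#include <vector>

// Trimmed stand-in for the XcclApi interface.
struct Api {
  virtual ~Api() = default;
  virtual int allReduce(const void *sendbuff, void *recvbuff,
                        std::size_t count) = 0;
};

// What a unit test would inject instead of the oneCCL-backed implementation.
struct RecordingFakeApi : Api {
  std::vector<std::size_t> calls; // counts seen, for assertions
  int allReduce(const void *, void *, std::size_t count) override {
    calls.push_back(count);
    return 0; // pretend success
  }
};

// Backend-agnostic call site, analogous to TorchCommXCCL holding an XcclApi.
int run_collective(Api &api, const float *in, float *out, std::size_t n) {
  return api.allReduce(in, out, n);
}

int main() {
  RecordingFakeApi fake;
  float in[4] = {1, 2, 3, 4};
  float out[4] = {};
  run_collective(fake, in, out, 4);
  std::cout << "allReduce calls recorded: " << fake.calls.size()
            << ", count=" << fake.calls.at(0) << "\n";
}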