From ead388343c69c1b412ec17e61034e4c40e5a53d8 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 12:05:36 -0500 Subject: [PATCH 01/16] Add RDMA host-side code --- .../iris_rdma/python/bindings.cpp | 152 ++++ iris/experimental/iris_rdma/src/ibv_utils.hpp | 94 +++ .../iris_rdma/src/network_backend.hpp | 745 ++++++++++++++++++ .../experimental/iris_rdma/src/queue_pair.hpp | 112 +++ .../iris_rdma/src/torch_bootstrap.hpp | 122 +++ 5 files changed, 1225 insertions(+) create mode 100644 iris/experimental/iris_rdma/python/bindings.cpp create mode 100644 iris/experimental/iris_rdma/src/ibv_utils.hpp create mode 100644 iris/experimental/iris_rdma/src/network_backend.hpp create mode 100644 iris/experimental/iris_rdma/src/queue_pair.hpp create mode 100644 iris/experimental/iris_rdma/src/torch_bootstrap.hpp diff --git a/iris/experimental/iris_rdma/python/bindings.cpp b/iris/experimental/iris_rdma/python/bindings.cpp new file mode 100644 index 00000000..a50ea69e --- /dev/null +++ b/iris/experimental/iris_rdma/python/bindings.cpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +/****************************************************************************** + * Python Bindings for Iris RDMA Backend using PyBind11 + *****************************************************************************/ + +#include +#include +#include +#include + +#include "network_backend.hpp" +#include "queue_pair.hpp" +#include "torch_bootstrap.hpp" + +namespace py = pybind11; +using namespace iris_rdma; + +PYBIND11_MODULE(_iris_rdma_backend, m) { + m.doc() = + "Iris RDMA Backend: InfiniBand RDMA with PyTorch Integration"; + + // Expose NICVendor enum + py::enum_(m, "NICVendor") + .value("NONE", NICVendor::NONE) + .value("IONIC", NICVendor::IONIC) + .value("BNXT", NICVendor::BNXT) + .value("MLX5", NICVendor::MLX5) + .export_values(); + + // Expose QPInfo struct + py::class_(m, "QPInfo") + .def(py::init<>()) + .def_readwrite("qp_num", &QPInfo::qp_num) + .def_readwrite("lkey", &QPInfo::lkey) + .def_readwrite("rkey", &QPInfo::rkey) + .def_readwrite("dst_rank", &QPInfo::dst_rank) + .def("__repr__", [](const QPInfo& info) { + return ""; + }); + + // Expose TorchBootstrap + py::class_>(m, + "TorchBootstrap") + .def(py::init([](py::object pg_obj) { + // Extract c10d::ProcessGroup from Python object + auto pg_ptr = + pg_obj.cast>(); + return std::make_shared(pg_ptr); + }), + py::arg("process_group")) + .def("get_rank", &TorchBootstrap::getRank) + .def("get_world_size", &TorchBootstrap::getWorldSize) + .def("barrier", &TorchBootstrap::barrier); + + // Expose QueuePair (read-only access) + py::class_(m, "QueuePair") + .def("get_qp_num", &QueuePair::getQPNum) + .def("get_lkey", &QueuePair::getLKey) + .def("get_rkey", &QueuePair::getRKey) + .def("get_dst_rank", &QueuePair::getDstRank) + .def("get_info", &QueuePair::getInfo) + .def("__repr__", [](const QueuePair& qp) { + return ""; + }); + + // Expose NetworkBackend + py::class_(m, "NetworkBackend") + .def(py::init, const char*>(), + py::arg("bootstrap"), py::arg("device_name") = nullptr, + "Create NetworkBackend with PyTorch bootstrap") + .def("init", &NetworkBackend::init, + "Initialize the network (setup QPs, transition to RTS)") + .def( + "register_memory", + [](NetworkBackend& self, py::object obj, size_t size = 0) { + void* ptr = nullptr; + size_t actual_size = size; + + // Check if it's an integer (raw pointer) + if (PyLong_Check(obj.ptr())) { + ptr = reinterpret_cast(PyLong_AsVoidPtr(obj.ptr())); + if (size == 0) { + throw std::runtime_error("Size must be specified for raw pointer"); + } + actual_size = size; + } + // Check if it's a PyTorch tensor + else if (THPVariable_Check(obj.ptr())) { + auto t = THPVariable_Unpack(obj.ptr()); + ptr = t.data_ptr(); + actual_size = t.numel() * t.element_size(); + + // Note: For GPU tensors, ibv_reg_mr will work if: + // 1. GPUDirect RDMA is enabled (check with ibstat/ibv_devinfo) + // 2. The memory is allocated with hipMalloc (native GPU memory) + // PyTorch tensors should work as they use hipMalloc internally + } + else { + throw std::runtime_error("Expected a PyTorch tensor or integer address"); + } + + self.registerMemory(ptr, actual_size); + }, + py::arg("obj"), py::arg("size") = 0, + "Register memory for RDMA (supports CPU pinned or GPU memory via GPUDirect)") + .def("get_qp", &NetworkBackend::getQP, py::arg("dst_rank"), + py::return_value_policy::reference_internal, + "Get queue pair for destination rank") + .def("get_qp_info", &NetworkBackend::getQPInfo, py::arg("dst_rank"), + "Get QP info for destination rank") + .def("get_rank", &NetworkBackend::getRank, "Get rank") + .def("get_world_size", &NetworkBackend::getWorldSize, "Get world size") + .def("get_remote_heap_base", &NetworkBackend::getRemoteHeapBase, + py::arg("rank"), + "Get remote heap base address for a rank") + .def("get_heap_base", &NetworkBackend::getHeapBase, + "Get local heap base address") + .def("get_heap_size", &NetworkBackend::getHeapSize, + "Get heap size in bytes") + .def("rdma_write", + [](NetworkBackend& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + return self.rdmaWrite(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA write to remote rank (local_addr is integer address)") + .def("rdma_read", + [](NetworkBackend& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + return self.rdmaRead(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA read from remote rank (local_addr is integer address)") + .def("poll_cq", &NetworkBackend::pollCQ, + py::arg("dst_rank"), py::arg("max_completions") = 1, + "Poll completion queue for RDMA operations") + .def("__repr__", [](const NetworkBackend& backend) { + return ""; + }); +} + diff --git a/iris/experimental/iris_rdma/src/ibv_utils.hpp b/iris/experimental/iris_rdma/src/ibv_utils.hpp new file mode 100644 index 00000000..ee55544f --- /dev/null +++ b/iris/experimental/iris_rdma/src/ibv_utils.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include +#include +#include + +namespace iris_rdma { + +// Error checking macros +#define CHECK_ZERO(expr, msg) \ + do { \ + int ret = (expr); \ + if (ret != 0) { \ + fprintf(stderr, "[ERROR] %s failed with code %d: %s\n", msg, ret, \ + strerror(ret)); \ + abort(); \ + } \ + } while (0) + +#define CHECK_NNULL(ptr, msg) \ + do { \ + if ((ptr) == nullptr) { \ + fprintf(stderr, "[ERROR] %s returned NULL\n", msg); \ + abort(); \ + } \ + } while (0) + +#define DEBUG_PRINT(fmt, ...) \ + do { \ + if (getenv("IRIS_RDMA_DEBUG")) { \ + fprintf(stderr, "[IRIS_RDMA_DEBUG] " fmt "\n", ##__VA_ARGS__); \ + } \ + } while (0) + +// Vendor detection +enum class NICVendor { NONE, IONIC, BNXT, MLX5 }; + +// QP destination info for connection +struct QPDestInfo { + int lid; + int qpn; + int psn; + union ibv_gid gid; +}; + +// QP metadata exposed to Python +struct QPInfo { + uint32_t qp_num; + uint32_t lkey; + uint32_t rkey; + int dst_rank; +}; + +// Helper functions +inline void dump_ibv_device(struct ibv_device* device) { + DEBUG_PRINT("IBV Device: %s", ibv_get_device_name(device)); +} + +inline void dump_ibv_context(struct ibv_context* ctx) { + DEBUG_PRINT("IBV Context: device=%s", ctx->device->name); +} + +inline void dump_ibv_pd(struct ibv_pd* pd) { + DEBUG_PRINT("IBV PD: handle=%u", pd->handle); +} + +inline void dump_ibv_port_attr(struct ibv_port_attr* attr) { + DEBUG_PRINT("Port Attr: state=%d, lid=%d, link_layer=%d, active_mtu=%d", + attr->state, attr->lid, attr->link_layer, attr->active_mtu); +} + +inline int ibv_mtu_to_int(enum ibv_mtu mtu) { + switch (mtu) { + case IBV_MTU_256: + return 256; + case IBV_MTU_512: + return 512; + case IBV_MTU_1024: + return 1024; + case IBV_MTU_2048: + return 2048; + case IBV_MTU_4096: + return 4096; + default: + fprintf(stderr, "[ERROR] Invalid ibv_mtu\n"); + return 0; + } +} + +} // namespace iris_rdma + diff --git a/iris/experimental/iris_rdma/src/network_backend.hpp b/iris/experimental/iris_rdma/src/network_backend.hpp new file mode 100644 index 00000000..8fee042e --- /dev/null +++ b/iris/experimental/iris_rdma/src/network_backend.hpp @@ -0,0 +1,745 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ibv_utils.hpp" +#include "queue_pair.hpp" +#include "torch_bootstrap.hpp" + +// Vendor-specific headers +#ifdef HAVE_MLX5 +#include +#endif + +#ifdef HAVE_BNXT +#include +#endif + +namespace iris_rdma { + +/** + * @brief Main network backend for InfiniBand setup + * + * Handles: + * - Device detection and initialization + * - Protection domain creation + * - Queue pair creation and state transitions + * - Memory registration + * - QP connection info exchange + */ +class NetworkBackend { + public: + /** + * @brief Constructor + * @param bootstrap PyTorch bootstrap for cross-rank communication + * @param device_name Optional device name (NULL for auto-detect) + */ + NetworkBackend(std::shared_ptr bootstrap, + const char* device_name = nullptr) + : bootstrap_(bootstrap), + requested_dev_(device_name), + context_(nullptr), + pd_orig_(nullptr), + pd_parent_(nullptr), + vendor_(NICVendor::NONE), + port_(1), + gid_index_(0), + heap_mr_(nullptr), + heap_base_(0), + heap_size_(0), + mlx5dv_handle_(nullptr), + bnxtdv_handle_(nullptr) { + if (!bootstrap_) { + throw std::runtime_error("Bootstrap cannot be null"); + } + rank_ = bootstrap_->getRank(); + world_size_ = bootstrap_->getWorldSize(); + DEBUG_PRINT("NetworkBackend created: rank=%d, world_size=%d", rank_, world_size_); + } + + /** + * @brief Destructor - cleanup InfiniBand resources + */ + ~NetworkBackend() { + DEBUG_PRINT("NetworkBackend cleanup started"); + + qps_.clear(); + + for (auto* cq : cqs_) { + if (cq) { + ibv_destroy_cq(cq); + } + } + cqs_.clear(); + + if (heap_mr_) { + ibv_dereg_mr(heap_mr_); + heap_mr_ = nullptr; + } + + if (pd_parent_) { + ibv_dealloc_pd(pd_parent_); + pd_parent_ = nullptr; + } + + if (pd_orig_) { + ibv_dealloc_pd(pd_orig_); + pd_orig_ = nullptr; + } + + if (context_) { + ibv_close_device(context_); + context_ = nullptr; + } + + if (mlx5dv_handle_) { + dlclose(mlx5dv_handle_); + mlx5dv_handle_ = nullptr; + } + + if (bnxtdv_handle_) { + dlclose(bnxtdv_handle_); + bnxtdv_handle_ = nullptr; + } + + DEBUG_PRINT("NetworkBackend cleanup completed"); + } + + /** + * @brief Initialize the network (setup QPs, transition to RTS) + */ + void init() { + DEBUG_PRINT("NetworkBackend::init() started"); + + autodetectDVLibs(); + openIBDevice(); + createQueues(); + exchangeQPDestInfo(); + modifyQPsResetToInit(); + modifyQPsInitToRTR(); + modifyQPsRTRToRTS(); + bootstrap_->barrier(); + + DEBUG_PRINT("NetworkBackend::init() completed"); + } + + /** + * @brief Register memory for RDMA + * @param ptr Pointer to memory region + * @param size Size in bytes + */ + void registerMemory(void* ptr, size_t size) { + DEBUG_PRINT("Registering memory: ptr=%p, size=%zu", ptr, size); + + int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; + + heap_mr_ = ibv_reg_mr(pd_orig_, ptr, size, access); + if (heap_mr_ == nullptr) { + int err = errno; + fprintf(stderr, "[ERROR] ibv_reg_mr returned NULL for ptr=%p, size=%zu, errno=%d (%s)\n", + ptr, size, err, strerror(err)); + char error_msg[256]; + snprintf(error_msg, sizeof(error_msg), + "ibv_reg_mr failed with errno %d (%s) - GPUDirect RDMA may not be enabled", + err, strerror(err)); + throw std::runtime_error(error_msg); + } + + // Store local heap base + heap_base_ = reinterpret_cast(ptr); + heap_size_ = size; + + // Exchange remote keys + rkeys_.resize(world_size_); + std::vector all_rkeys(world_size_); + all_rkeys[rank_] = heap_mr_->rkey; + bootstrap_->allGather(all_rkeys.data(), sizeof(uint32_t)); + for (int i = 0; i < world_size_; i++) { + rkeys_[i] = all_rkeys[i]; + } + + // Exchange heap base addresses (collective operation) + remote_heap_bases_.resize(world_size_); + std::vector all_heap_bases(world_size_); + all_heap_bases[rank_] = heap_base_; + bootstrap_->allGather(all_heap_bases.data(), sizeof(uint64_t)); + for (int i = 0; i < world_size_; i++) { + remote_heap_bases_[i] = all_heap_bases[i]; + } + + // Update QPs with lkey and rkey + uint32_t lkey = heap_mr_->lkey; + for (int i = 0; i < world_size_; i++) { + if (i < qps_.size() && qps_[i]) { + qps_[i]->setLKey(lkey); + qps_[i]->setRKey(rkeys_[i]); + } + } + + DEBUG_PRINT("Memory registered: lkey=%u, rkey=%u, heap_base=%p", + lkey, heap_mr_->rkey, ptr); + } + + /** + * @brief Get queue pair for destination rank + * @param dst_rank Destination rank + * @return Pointer to QueuePair object + */ + QueuePair* getQP(int dst_rank) { + if (dst_rank >= 0 && dst_rank < qps_.size()) { + return qps_[dst_rank].get(); + } + return nullptr; + } + + /** + * @brief Get QP info for Python + * @param dst_rank Destination rank + * @return QPInfo structure + */ + QPInfo getQPInfo(int dst_rank) { + QueuePair* qp = getQP(dst_rank); + if (qp) { + return qp->getInfo(); + } + return QPInfo{0, 0, 0, dst_rank}; + } + + + + + /** + * @brief Get rank + */ + int getRank() const { return rank_; } + + /** + * @brief Get world size + */ + int getWorldSize() const { return world_size_; } + + /** + * @brief Get remote heap base address for a rank + * @param rank Remote rank + * @return Remote heap base address (0 if not registered) + */ + uint64_t getRemoteHeapBase(int rank) const { + if (rank >= 0 && rank < remote_heap_bases_.size()) { + return remote_heap_bases_[rank]; + } + return 0; + } + + /** + * @brief Get local heap base address + * @return Local heap base address (0 if not registered) + */ + uint64_t getHeapBase() const { return heap_base_; } + + /** + * @brief Get heap size + * @return Heap size in bytes (0 if not registered) + */ + size_t getHeapSize() const { return heap_size_; } + + /** + * @brief RDMA Write operation + * @param dst_rank Destination rank + * @param local_addr Local buffer address + * @param remote_addr Remote buffer address + * @param size Size in bytes + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdmaWrite(int dst_rank, void* local_addr, uint64_t remote_addr, + size_t size, uint64_t wr_id = 0) { + QueuePair* qp = getQP(dst_rank); + if (!qp) { + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)local_addr; + sge.length = size; + sge.lkey = qp->getLKey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = qp->getRKey(); + + struct ibv_send_wr* bad_wr; + int ret = ibv_post_send(qp->getIBVQP(), &wr, &bad_wr); + + DEBUG_PRINT("RDMA Write to rank %d: local=%p remote=%lx size=%zu ret=%d", + dst_rank, local_addr, remote_addr, size, ret); + + return ret; + } + + /** + * @brief RDMA Read operation + * @param dst_rank Destination rank + * @param local_addr Local buffer address + * @param remote_addr Remote buffer address + * @param size Size in bytes + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdmaRead(int dst_rank, void* local_addr, uint64_t remote_addr, + size_t size, uint64_t wr_id = 0) { + QueuePair* qp = getQP(dst_rank); + if (!qp) { + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)local_addr; + sge.length = size; + sge.lkey = qp->getLKey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_READ; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = qp->getRKey(); + + struct ibv_send_wr* bad_wr; + int ret = ibv_post_send(qp->getIBVQP(), &wr, &bad_wr); + + DEBUG_PRINT("RDMA Read from rank %d: local=%p remote=%lx size=%zu ret=%d", + dst_rank, local_addr, remote_addr, size, ret); + + return ret; + } + + /** + * @brief Poll completion queue for RDMA operations + * @param dst_rank Destination rank (to poll specific CQ) + * @param max_completions Maximum number of completions to poll + * @return Number of completions polled (negative on error) + */ + int pollCQ(int dst_rank, int max_completions = 1) { + QueuePair* qp = getQP(dst_rank); + if (!qp) { + return -1; + } + + struct ibv_wc wc[16]; + int num_to_poll = (max_completions < 16) ? max_completions : 16; + int n = ibv_poll_cq(qp->getIBVCQ(), num_to_poll, wc); + + if (n < 0) { + DEBUG_PRINT("CQ poll error for rank %d", dst_rank); + return n; + } + + // Check for errors in completions + for (int i = 0; i < n; i++) { + if (wc[i].status != IBV_WC_SUCCESS) { + fprintf(stderr, "[ERROR] Work completion failed: status=%d (%s) wr_id=%lu\n", + wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].wr_id); + return -1; + } + } + + DEBUG_PRINT("Polled %d completions from rank %d", n, dst_rank); + return n; + } + + + + private: + // Bootstrap + std::shared_ptr bootstrap_; + int rank_; + int world_size_; + + // Device configuration + const char* requested_dev_; + struct ibv_context* context_; + struct ibv_pd* pd_orig_; + struct ibv_pd* pd_parent_; // For MLX5/IONIC + NICVendor vendor_; + + // Port configuration + struct ibv_port_attr portinfo_; + union ibv_gid gid_; + int port_; + int gid_index_; + + // Memory registration + struct ibv_mr* heap_mr_; + std::vector rkeys_; // Remote keys from all ranks + uint64_t heap_base_; // Local heap base address + size_t heap_size_; // Local heap size + std::vector remote_heap_bases_; // Heap base addresses from all ranks + + // Queue pairs + std::vector> qps_; + std::vector cqs_; + std::vector dest_info_; + + // Dynamic library handles for vendor-specific libraries + void* mlx5dv_handle_; + void* bnxtdv_handle_; + + // Setup functions (extracted from rocSHMEM) + + // Vendor-specific init + void autodetectDVLibs() { + DEBUG_PRINT("Auto-detecting vendor libraries..."); + + // Try MLX5 + if (mlx5DVDLInit() == 0) { + vendor_ = NICVendor::MLX5; + DEBUG_PRINT("Detected MLX5 vendor"); + return; + } + + // Try BNXT + if (bnxtDVDLInit() == 0) { + vendor_ = NICVendor::BNXT; + DEBUG_PRINT("Detected BNXT vendor"); + return; + } + + // Default to standard verbs + vendor_ = NICVendor::NONE; + DEBUG_PRINT("Using standard InfiniBand verbs"); + } + + int mlx5DVDLInit() { + mlx5dv_handle_ = dlopen("libmlx5.so", RTLD_NOW); + if (!mlx5dv_handle_) { + mlx5dv_handle_ = dlopen("libmlx5.so.1", RTLD_NOW); + } + + if (!mlx5dv_handle_) { + DEBUG_PRINT("Could not open libmlx5.so"); + return -1; + } + + return 0; + } + + int bnxtDVDLInit() { + bnxtdv_handle_ = dlopen("libbnxt_re.so", RTLD_NOW); + if (!bnxtdv_handle_) { + bnxtdv_handle_ = dlopen("/usr/local/lib/libbnxt_re.so", RTLD_NOW); + } + + if (!bnxtdv_handle_) { + DEBUG_PRINT("Could not open libbnxt_re.so"); + return -1; + } + + return 0; + } + + void openIBDevice() { + DEBUG_PRINT("Opening InfiniBand device..."); + + struct ibv_device** device_list = nullptr; + struct ibv_device* device = nullptr; + int num_devices = 0; + + device_list = ibv_get_device_list(&num_devices); + CHECK_NNULL(device_list, "ibv_get_device_list"); + + if (num_devices == 0) { + throw std::runtime_error("No InfiniBand devices found"); + } + + // Select device + device = device_list[0]; // Default to first device + + if (requested_dev_) { + for (int i = 0; i < num_devices; i++) { + const char* dev_name = ibv_get_device_name(device_list[i]); + CHECK_NNULL(dev_name, "ibv_get_device_name"); + + if (strstr(dev_name, requested_dev_)) { + device = device_list[i]; + break; + } + } + } + + // Open device + context_ = ibv_open_device(device); + CHECK_NNULL(context_, "ibv_open_device"); + dump_ibv_context(context_); + dump_ibv_device(context_->device); + + // Allocate protection domain + pd_orig_ = ibv_alloc_pd(context_); + CHECK_NNULL(pd_orig_, "ibv_alloc_pd"); + dump_ibv_pd(pd_orig_); + + // Create parent domain for MLX5/IONIC + if (vendor_ == NICVendor::MLX5) { + createParentDomain(); + } + + // Query port + int err = ibv_query_port(context_, port_, &portinfo_); + CHECK_ZERO(err, "ibv_query_port"); + dump_ibv_port_attr(&portinfo_); + + // Select GID index + selectGIDIndex(); + + ibv_free_device_list(device_list); + + DEBUG_PRINT("InfiniBand device opened: %s", + ibv_get_device_name(context_->device)); + } + + void createParentDomain() { + DEBUG_PRINT("Creating parent domain..."); + + struct ibv_parent_domain_init_attr pattr; + memset(&pattr, 0, sizeof(pattr)); + + pattr.pd = pd_orig_; + pattr.td = nullptr; + pattr.comp_mask = 0; + + pd_parent_ = ibv_alloc_parent_domain(context_, &pattr); + CHECK_NNULL(pd_parent_, "ibv_alloc_parent_domain"); + dump_ibv_pd(pd_parent_); + } + + void selectGIDIndex() { + DEBUG_PRINT("Selecting GID index..."); + + const uint8_t local_gid_prefix[2] = {0xFE, 0x80}; + int selected_gid_index = -1; + union ibv_gid selected_gid; + int err; + + int gid_tbl_len = portinfo_.gid_tbl_len; + + for (int i = 0; i < gid_tbl_len; i++) { + union ibv_gid current_gid; + err = ibv_query_gid(context_, port_, i, ¤t_gid); + if (err != 0) + continue; + + // Skip local GIDs + if (memcmp(current_gid.raw, &local_gid_prefix, 2) == 0) { + continue; + } + + // Use first non-local GID + if (selected_gid_index == -1) { + selected_gid_index = i; + selected_gid = current_gid; + break; + } + } + + if (selected_gid_index == -1) { + selected_gid_index = 0; + err = ibv_query_gid(context_, port_, 0, &selected_gid); + CHECK_ZERO(err, "ibv_query_gid"); + } + + gid_index_ = selected_gid_index; + gid_ = selected_gid; + + DEBUG_PRINT("Selected GID index: %d", gid_index_); + } + + void createQueues() { + DEBUG_PRINT("Creating queues..."); + + int ncqes = 64; // Number of CQ entries + int sq_length = 64; // Send queue length + + // Resize vectors + dest_info_.resize(world_size_); + cqs_.resize(world_size_); + qps_.resize(world_size_); + + // Create CQs and QPs + createCQs(ncqes); + createQPs(sq_length); + + DEBUG_PRINT("Created %d queue pairs", world_size_); + } + + void createCQs(int ncqes) { + DEBUG_PRINT("Creating completion queues: ncqes=%d", ncqes); + + struct ibv_cq_init_attr_ex cq_attr; + memset(&cq_attr, 0, sizeof(cq_attr)); + + cq_attr.cqe = ncqes; + cq_attr.cq_context = nullptr; + cq_attr.channel = nullptr; + cq_attr.comp_vector = 0; + cq_attr.flags = 0; + + if (pd_parent_) { + cq_attr.comp_mask = IBV_CQ_INIT_ATTR_MASK_PD; + cq_attr.parent_domain = pd_parent_; + } + + for (int i = 0; i < world_size_; i++) { + struct ibv_cq_ex* cq_ex = ibv_create_cq_ex(context_, &cq_attr); + CHECK_NNULL(cq_ex, "ibv_create_cq_ex"); + + cqs_[i] = ibv_cq_ex_to_cq(cq_ex); + CHECK_NNULL(cqs_[i], "ibv_cq_ex_to_cq"); + } + } + + void createQPs(int sq_length) { + DEBUG_PRINT("Creating queue pairs: sq_length=%d", sq_length); + + struct ibv_qp_init_attr_ex attr; + memset(&attr, 0, sizeof(attr)); + + attr.cap.max_send_wr = sq_length; + attr.cap.max_send_sge = 1; + attr.cap.max_inline_data = 8; + attr.sq_sig_all = 0; + attr.qp_type = IBV_QPT_RC; + attr.comp_mask = IBV_QP_INIT_ATTR_PD; + attr.pd = pd_parent_ ? pd_parent_ : pd_orig_; + + for (int i = 0; i < world_size_; i++) { + attr.send_cq = cqs_[i]; + attr.recv_cq = cqs_[i]; + + struct ibv_qp* qp = ibv_create_qp_ex(context_, &attr); + CHECK_NNULL(qp, "ibv_create_qp_ex"); + + qps_[i] = std::make_unique(qp, cqs_[i], i, vendor_); + } + } + + void exchangeQPDestInfo() { + DEBUG_PRINT("Exchanging QP destination info..."); + + // Fill local dest info + for (int i = 0; i < world_size_; i++) { + dest_info_[i].lid = portinfo_.lid; + dest_info_[i].qpn = qps_[i]->getQPNum(); + dest_info_[i].psn = 0; + dest_info_[i].gid = gid_; + } + + // All-gather dest info + bootstrap_->allGather(dest_info_.data(), sizeof(QPDestInfo)); + + DEBUG_PRINT("QP destination info exchanged"); + } + + void modifyQPsResetToInit() { + DEBUG_PRINT("Transitioning QPs: RESET -> INIT"); + + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = port_; + attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; + + int attr_mask = + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; + + for (int i = 0; i < world_size_; i++) { + int err = ibv_modify_qp(qps_[i]->getIBVQP(), &attr, attr_mask); + CHECK_ZERO(err, "modify_qp (RESET->INIT)"); + } + } + + void modifyQPsInitToRTR() { + DEBUG_PRINT("Transitioning QPs: INIT -> RTR"); + + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = portinfo_.active_mtu; + attr.min_rnr_timer = 12; + attr.max_dest_rd_atomic = 1; + attr.ah_attr.port_num = port_; + + if (portinfo_.link_layer == IBV_LINK_LAYER_ETHERNET) { + attr.ah_attr.grh.sgid_index = gid_index_; + attr.ah_attr.is_global = 1; + attr.ah_attr.grh.hop_limit = 1; + attr.ah_attr.sl = 1; + attr.ah_attr.grh.traffic_class = 0; + } + + int attr_mask = IBV_QP_STATE | IBV_QP_PATH_MTU | IBV_QP_RQ_PSN | + IBV_QP_DEST_QPN | IBV_QP_AV | IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER; + + for (int i = 0; i < world_size_; i++) { + attr.rq_psn = dest_info_[i].psn; + attr.dest_qp_num = dest_info_[i].qpn; + + if (portinfo_.link_layer == IBV_LINK_LAYER_ETHERNET) { + memcpy(&attr.ah_attr.grh.dgid, &dest_info_[i].gid, 16); + } else { + attr.ah_attr.dlid = dest_info_[i].lid; + } + + int err = ibv_modify_qp(qps_[i]->getIBVQP(), &attr, attr_mask); + CHECK_ZERO(err, "modify_qp (INIT->RTR)"); + } + } + + void modifyQPsRTRToRTS() { + DEBUG_PRINT("Transitioning QPs: RTR -> RTS"); + + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_RTS; + attr.timeout = 14; + attr.retry_cnt = 7; + attr.rnr_retry = 7; + attr.max_rd_atomic = 1; + + int attr_mask = IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC | + IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY; + + for (int i = 0; i < world_size_; i++) { + attr.sq_psn = dest_info_[i].psn; + + int err = ibv_modify_qp(qps_[i]->getIBVQP(), &attr, attr_mask); + CHECK_ZERO(err, "modify_qp (RTR->RTS)"); + } + } + +}; + +} // namespace iris_rdma diff --git a/iris/experimental/iris_rdma/src/queue_pair.hpp b/iris/experimental/iris_rdma/src/queue_pair.hpp new file mode 100644 index 00000000..dff923c9 --- /dev/null +++ b/iris/experimental/iris_rdma/src/queue_pair.hpp @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include "ibv_utils.hpp" + +namespace iris_rdma { + +/** + * @brief Simplified Queue Pair wrapper for host-side operations + * + * Unlike the full rocSHMEM QueuePair, this version only maintains + * metadata needed for RDMA operations from Python/host code. + */ +class QueuePair { + public: + /** + * @brief Constructor + * @param qp InfiniBand queue pair + * @param cq InfiniBand completion queue + * @param dst_rank Destination rank for this QP + * @param vendor NIC vendor type + */ + inline QueuePair(struct ibv_qp* qp, + struct ibv_cq* cq, + int dst_rank, + NICVendor vendor) + : qp_(qp), + cq_(cq), + dst_rank_(dst_rank), + vendor_(vendor), + lkey_(0), + rkey_(0) { + CHECK_NNULL(qp_, "QueuePair: ibv_qp"); + CHECK_NNULL(cq_, "QueuePair: ibv_cq"); + qp_num_ = qp_->qp_num; + DEBUG_PRINT("QueuePair created: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + } + + /** + * @brief Destructor + */ + inline ~QueuePair() { + DEBUG_PRINT("QueuePair destroyed: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + } + + /** + * @brief Get QP number + */ + uint32_t getQPNum() const { return qp_num_; } + + /** + * @brief Get local key for memory region + */ + uint32_t getLKey() const { return lkey_; } + + /** + * @brief Get remote key for destination rank + */ + uint32_t getRKey() const { return rkey_; } + + /** + * @brief Get destination rank + */ + int getDstRank() const { return dst_rank_; } + + /** + * @brief Set remote key (after exchange) + */ + void setRKey(uint32_t rkey) { rkey_ = rkey; } + + /** + * @brief Set local key (from memory registration) + */ + void setLKey(uint32_t lkey) { lkey_ = lkey; } + + /** + * @brief Get underlying ibv_qp pointer + */ + struct ibv_qp* getIBVQP() { return qp_; } + + /** + * @brief Get underlying ibv_cq pointer + */ + struct ibv_cq* getIBVCQ() { return cq_; } + + /** + * @brief Get QP info for Python + */ + inline QPInfo getInfo() const { + QPInfo info; + info.qp_num = qp_num_; + info.lkey = lkey_; + info.rkey = rkey_; + info.dst_rank = dst_rank_; + return info; + } + + private: + struct ibv_qp* qp_; + struct ibv_cq* cq_; + int dst_rank_; + NICVendor vendor_; + + uint32_t qp_num_; + uint32_t lkey_; + uint32_t rkey_; +}; + +} // namespace iris_rdma + diff --git a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp new file mode 100644 index 00000000..e714b8dd --- /dev/null +++ b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ibv_utils.hpp" + +namespace iris_rdma { + +/** + * @brief Bootstrap implementation using PyTorch Distributed + * + * Wraps PyTorch's c10d process group to provide synchronization + * primitives needed for InfiniBand setup (allGather, barrier) + */ +class TorchBootstrap { + public: + /** + * @brief Constructor + * @param process_group PyTorch distributed process group + */ + inline explicit TorchBootstrap(c10::intrusive_ptr process_group) + : process_group_(process_group) { + if (!process_group_) { + throw std::runtime_error("Process group cannot be null"); + } + rank_ = process_group_->getRank(); + world_size_ = process_group_->getSize(); + DEBUG_PRINT("TorchBootstrap initialized: rank=%d, world_size=%d", rank_, world_size_); + } + + /** + * @brief Get rank of current process + */ + int getRank() const { return rank_; } + + /** + * @brief Get total number of ranks + */ + int getWorldSize() const { return world_size_; } + + /** + * @brief All-gather operation + * + * Gathers data from all ranks. Each rank contributes 'size' bytes + * starting at allData[rank * size]. + * + * @param allData Buffer to hold all gathered data (world_size * size bytes) + * @param size Size of data contributed by each rank + */ + inline void allGather(void* allData, int size) { + auto cpu_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto cuda_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + + auto cpu_input = torch::from_blob( + static_cast(allData) + rank_ * size, {size}, cpu_options); + auto input = cpu_input.to(torch::kCUDA); + + std::vector output_tensors; + for (int i = 0; i < world_size_; i++) { + output_tensors.push_back(torch::empty({size}, cuda_options)); + } + + std::vector> output_tensor_lists = {output_tensors}; + std::vector input_tensors = {input}; + auto work = process_group_->allgather(output_tensor_lists, input_tensors); + work->wait(); + + for (int i = 0; i < world_size_; i++) { + auto cpu_output = output_tensors[i].to(torch::kCPU); + std::memcpy(static_cast(allData) + i * size, + cpu_output.data_ptr(), size); + } + DEBUG_PRINT("AllGather completed: %d bytes per rank", size); + } + + /** + * @brief Barrier synchronization + * + * Blocks until all ranks reach the barrier + */ + inline void barrier() { + auto work = process_group_->barrier(); + work->wait(); + DEBUG_PRINT("Barrier completed"); + } + + /** + * @brief Point-to-point send (optional, not needed for basic setup) + */ + inline void send(void* data, int size, int peer, int tag) { + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto tensor = torch::from_blob(static_cast(data), {size}, options); + std::vector tensors = {tensor}; + auto work = process_group_->send(tensors, peer, tag); + work->wait(); + } + + /** + * @brief Point-to-point receive (optional, not needed for basic setup) + */ + inline void recv(void* data, int size, int peer, int tag) { + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto tensor = torch::from_blob(static_cast(data), {size}, options); + std::vector tensors = {tensor}; + auto work = process_group_->recv(tensors, peer, tag); + work->wait(); + } + + private: + c10::intrusive_ptr process_group_; + int rank_; + int world_size_; +}; + +} // namespace iris_rdma + From 94f6164efb658348788a5fe1d64b8014b8d9c26f Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 12:05:45 -0500 Subject: [PATCH 02/16] Add setup.py --- setup.py | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 69832461..59fa1192 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,108 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. -from setuptools import setup +import os +import subprocess +import sys +from pathlib import Path +from setuptools import setup, Extension +from setuptools.command.build_ext import build_ext + + +class CMakeExtension(Extension): + """Extension that uses CMake to build""" + def __init__(self, name, sourcedir=""): + super().__init__(name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +class CMakeBuild(build_ext): + """Custom build_ext command that runs CMake""" + + def run(self): + # Check if CMake is available + try: + subprocess.check_output(["cmake", "--version"]) + except OSError: + raise RuntimeError("CMake must be installed to build RDMA extensions") + + # Build each extension + for ext in self.extensions: + self.build_extension(ext) + + def build_extension(self, ext): + if not isinstance(ext, CMakeExtension): + return super().build_extension(ext) + + extdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute() + + # CMake configuration arguments + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + "-DCMAKE_BUILD_TYPE=Release", + ] + + # Build arguments + build_args = ["--config", "Release"] + + # Parallel build + if hasattr(os, "cpu_count"): + build_args += [f"-j{os.cpu_count()}"] + + # Create build directory + build_temp = Path(self.build_temp) / ext.name + build_temp.mkdir(parents=True, exist_ok=True) + + # Run CMake + subprocess.check_call( + ["cmake", ext.sourcedir] + cmake_args, + cwd=build_temp + ) + + # Build + subprocess.check_call( + ["cmake", "--build", "."] + build_args, + cwd=build_temp + ) + + +# Check if InfiniBand libraries are available (optional RDMA support) +def has_infiniband(): + """Check if InfiniBand development libraries are available""" + try: + result = subprocess.run( + ["pkg-config", "--exists", "libibverbs"], + capture_output=True + ) + return result.returncode == 0 + except FileNotFoundError: + # pkg-config not available, try to find library directly + for path in ["/usr/lib", "/usr/lib64", "/usr/local/lib"]: + if os.path.exists(os.path.join(path, "libibverbs.so")): + return True + return False + + +# Build RDMA extension if InfiniBand is available +ext_modules = [] +if has_infiniband(): + print("InfiniBand libraries detected - building RDMA backend") + rdma_ext = CMakeExtension( + "iris.experimental._iris_rdma_backend", + sourcedir="iris/experimental/iris_rdma" + ) + ext_modules.append(rdma_ext) +else: + print("InfiniBand libraries not found - skipping RDMA backend") + print("To enable RDMA support, install: libibverbs-dev (Ubuntu/Debian) or rdma-core-devel (RHEL/CentOS)") + # This setup.py provides backward compatibility for legacy metadata fields # that don't map directly from pyproject.toml's modern PEP 621 format. setup( url="https://rocm.github.io/iris/", author="Muhammad Awad, Muhammad Osama, Brandon Potter", + ext_modules=ext_modules, + cmdclass={"build_ext": CMakeBuild} if ext_modules else {}, ) From 71c85f2269720f329feeb982cee1edc1c83a7ad2 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 12:06:20 -0500 Subject: [PATCH 03/16] Add RDMA iris backend --- iris/experimental/__init__.py | 34 ++- iris/experimental/iris_rdma.py | 386 +++++++++++++++++++++++++++++++++ 2 files changed, 419 insertions(+), 1 deletion(-) create mode 100644 iris/experimental/iris_rdma.py diff --git a/iris/experimental/__init__.py b/iris/experimental/__init__.py index dbab5167..44369673 100644 --- a/iris/experimental/__init__.py +++ b/iris/experimental/__init__.py @@ -9,8 +9,9 @@ Current experimental features: - iris_gluon: Gluon-based implementation using @aggregate and @gluon.jit +- iris_rdma: InfiniBand RDMA backend for multi-node communication -Usage: +Usage (Gluon): >>> import iris.experimental.iris_gluon as iris_gl >>> from triton.experimental import gluon >>> from triton.experimental.gluon import language as gl @@ -24,8 +25,39 @@ >>> def kernel(IrisDeviceCtx: gl.constexpr, context_tensor): >>> ctx = IrisDeviceCtx.initialize(context_tensor) >>> ctx.load(buffer, 1) + +Usage (RDMA): + >>> import iris.experimental.iris_rdma as iris_rdma + >>> import torch.distributed as dist + >>> + >>> # Initialize PyTorch Distributed first + >>> dist.init_process_group(backend='nccl') + >>> + >>> # Host side + >>> ctx = iris_rdma.iris(heap_size=2**30) + >>> device_ctx = ctx.get_device_context() + >>> + >>> # Device side + >>> @triton.jit + >>> def kernel(dst_ptr, data, device_ctx, dst_rank): + >>> iris_rdma.put(dst_ptr, data, dst_rank, device_ctx, mask) """ from . import iris_gluon +# Try to import iris_rdma (optional, requires InfiniBand) +try: + from . import iris_rdma + _has_rdma = True +except ImportError as e: + _has_rdma = False + import warnings + warnings.warn( + f"iris_rdma not available: {e}\n" + "InfiniBand RDMA support requires libibverbs-dev and building with CMake.", + ImportWarning + ) + __all__ = ["iris_gluon"] +if _has_rdma: + __all__.append("iris_rdma") diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py new file mode 100644 index 00000000..c686e544 --- /dev/null +++ b/iris/experimental/iris_rdma.py @@ -0,0 +1,386 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Iris RDMA: Experimental InfiniBand RDMA Backend for Multi-Node Communication + +This module provides InfiniBand RDMA support for multi-node communication in Iris. +Unlike the main Iris which uses HIP IPC for intra-node GPU communication, this backend +enables inter-node communication via RDMA over InfiniBand. + +Key Features: +- InfiniBand Queue Pair (QP) setup and management +- Symmetric heap with RDMA memory registration +- RDMA put/get operations in Triton kernels +- PyTorch Distributed integration for bootstrapping + +Example: + >>> import iris.experimental.iris_rdma as iris_rdma + >>> import torch.distributed as dist + >>> + >>> # Initialize PyTorch Distributed + >>> dist.init_process_group(backend='nccl') + >>> + >>> # Create RDMA context + >>> ctx = iris_rdma.iris(heap_size=2**30) # 1GB heap + >>> device_ctx = ctx.get_device_context() # For passing to Triton kernels + >>> + >>> @triton.jit + >>> def kernel(dst_ptr, data, device_ctx, dst_rank, BLOCK_SIZE: tl.constexpr): + >>> pid = tl.program_id(0) + >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + >>> + >>> # RDMA put to remote rank + >>> iris_rdma.put(dst_ptr + offsets, data, dst_rank, device_ctx) +""" + +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import numpy as np +import sys +import os + +# Import the C++ backend module +try: + from . import _iris_rdma_backend as backend +except ImportError: + raise ImportError( + "Iris RDMA backend not available. " + "Make sure the module is built with InfiniBand support. " + "Set IRIS_RDMA_DEBUG=1 for more information." + ) + +# Import logging +from ..logging import logger + + +class IrisRDMA: + """ + Main Iris RDMA class for multi-node RDMA operations. + + This class provides a unified interface for RDMA-based communication + across multiple nodes using InfiniBand. + + Args: + heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) + process_group: PyTorch distributed process group (default: WORLD) + device_name (str): InfiniBand device name (default: auto-detect) + + Example: + >>> ctx = iris_rdma.iris(heap_size=2**31) # 2GB heap + >>> print(f"Rank {ctx.rank} of {ctx.world_size}") + >>> buffer = ctx.zeros(1024, dtype=torch.float32) + """ + + def __init__(self, heap_size=1 << 30, process_group=None, device_name=None): + # Check if distributed is initialized + if not dist.is_initialized(): + raise RuntimeError( + "PyTorch distributed must be initialized. " + "Call torch.distributed.init_process_group() first." + ) + + if process_group is None: + process_group = dist.group.WORLD + + # Get rank and world size + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.device_id = self.rank % torch.cuda.device_count() + self.device = f"cuda:{self.device_id}" + + torch.cuda.set_device(self.device_id) + + # Create TorchBootstrap + self._bootstrap = backend.TorchBootstrap(process_group) + + # Create NetworkBackend + self._backend = backend.NetworkBackend(self._bootstrap, device_name) + + # Initialize network (create QPs, transition to RTS) + self._backend.init() + + # Allocate symmetric heap (CPU pinned memory for now) + # TODO: Support GPU memory with GPUDirect RDMA + self.heap_size = heap_size + self.heap_offset = 0 + self.alignment = 1024 + + # Create CPU pinned memory pool + # For GPU memory, use: torch.empty(heap_size, device=self.device, dtype=torch.int8) + self.memory_pool = torch.empty(heap_size, device='cpu', dtype=torch.int8).pin_memory() + + # Register memory with RDMA + self._backend.register_memory(self.memory_pool) + + # Store remote heap bases (already exchanged in register_memory) + self.remote_heap_bases = [] + for i in range(self.world_size): + self.remote_heap_bases.append(self._backend.get_remote_heap_base(i)) + + logger.info(f"[Rank {self.rank}] Iris RDMA initialized: heap_size={heap_size}, " + f"heap_base={self._backend.get_heap_base():#x}") + + def get_device_context(self): + """ + Get device context tensor for passing to Triton kernels. + + The context tensor encodes: + - [0]: current rank + - [1]: world size + - [2:]: heap base addresses for all ranks + + Returns: + torch.Tensor: Device context tensor (on GPU) + + Example: + >>> ctx = iris_rdma.iris() + >>> device_ctx = ctx.get_device_context() + >>> # Pass device_ctx to Triton kernel + """ + # Create context tensor: [rank, world_size, heap_base_0, heap_base_1, ...] + context_size = 2 + self.world_size + context = torch.zeros(context_size, dtype=torch.int64, device=self.device) + + context[0] = self.rank + context[1] = self.world_size + + for i in range(self.world_size): + context[2 + i] = self.remote_heap_bases[i] + + return context + + def zeros(self, *size, dtype=torch.float32, device=None): + """ + Allocate and initialize a tensor with zeros in the symmetric heap. + + Args: + *size: Tensor dimensions + dtype: Data type (default: torch.float32) + device: Device placement ('cpu' or 'cuda', default: match context) + + Returns: + torch.Tensor: Allocated tensor + + Example: + >>> buffer = ctx.zeros(1024, 1024, dtype=torch.float32) + """ + if device is None: + device = 'cpu' # Use CPU for now (pinned memory) + + # Calculate size in bytes + elem_size = torch.tensor([], dtype=dtype).element_size() + numel = int(np.prod(size)) + size_bytes = numel * elem_size + + # Align allocation + aligned_offset = (self.heap_offset + self.alignment - 1) // self.alignment * self.alignment + + if aligned_offset + size_bytes > self.heap_size: + raise RuntimeError(f"Heap exhausted: requested {size_bytes} bytes, " + f"available {self.heap_size - aligned_offset}") + + # Create tensor view into memory pool + byte_offset = aligned_offset + byte_end = byte_offset + size_bytes + + # Get the memory slice and view as the requested dtype + memory_slice = self.memory_pool[byte_offset:byte_end] + tensor = memory_slice.view(dtype).reshape(size) + + # Zero initialize + tensor.zero_() + + # Update offset + self.heap_offset = byte_end + + logger.debug(f"[Rank {self.rank}] Allocated tensor: size={size}, " + f"offset={byte_offset:#x}, ptr={tensor.data_ptr():#x}") + + return tensor + + def barrier(self): + """ + Synchronize all ranks. + + Example: + >>> ctx.barrier() # Wait for all ranks + """ + dist.barrier() + + def rdma_put(self, dst_rank, local_addr, remote_addr, size): + """ + Perform RDMA write (put) to remote rank. + + Args: + dst_rank: Destination rank + local_addr: Local buffer address (int or tensor.data_ptr()) + remote_addr: Remote buffer address (int) + size: Size in bytes + + Returns: + int: 0 on success, non-zero on error + + Example: + >>> src = ctx.zeros(1024, dtype=torch.float32) + >>> dst_addr = ctx.remote_heap_bases[1] # Remote rank 1's heap + >>> ctx.rdma_put(1, src.data_ptr(), dst_addr, src.numel() * 4) + """ + if isinstance(local_addr, torch.Tensor): + local_addr = local_addr.data_ptr() + + return self._backend.rdma_write(dst_rank, local_addr, remote_addr, size) + + def rdma_get(self, dst_rank, local_addr, remote_addr, size): + """ + Perform RDMA read (get) from remote rank. + + Args: + dst_rank: Source rank (destination of the QP) + local_addr: Local buffer address (int or tensor.data_ptr()) + remote_addr: Remote buffer address (int) + size: Size in bytes + + Returns: + int: 0 on success, non-zero on error + + Example: + >>> dst = ctx.zeros(1024, dtype=torch.float32) + >>> src_addr = ctx.remote_heap_bases[1] # Remote rank 1's heap + >>> ctx.rdma_get(1, dst.data_ptr(), src_addr, dst.numel() * 4) + """ + if isinstance(local_addr, torch.Tensor): + local_addr = local_addr.data_ptr() + + return self._backend.rdma_read(dst_rank, local_addr, remote_addr, size) + + def poll_completion(self, dst_rank, max_completions=1): + """ + Poll completion queue for RDMA operations. + + Args: + dst_rank: Destination rank (to poll specific CQ) + max_completions: Maximum number of completions to poll + + Returns: + int: Number of completions polled (negative on error) + + Example: + >>> ctx.rdma_put(1, src.data_ptr(), remote_addr, size) + >>> while ctx.poll_completion(1) == 0: + >>> pass # Wait for completion + """ + return self._backend.poll_cq(dst_rank, max_completions) + + def __repr__(self): + return f"" + + +def iris(heap_size=1 << 30, process_group=None, device_name=None): + """ + Factory function to create Iris RDMA context. + + Args: + heap_size (int): Size of the symmetric heap in bytes + process_group: PyTorch distributed process group + device_name (str): InfiniBand device name (optional) + + Returns: + IrisRDMA: RDMA context object + + Example: + >>> import iris.experimental.iris_rdma as iris_rdma + >>> ctx = iris_rdma.iris(heap_size=2**30) + """ + return IrisRDMA(heap_size, process_group, device_name) + + +############################################################################# +# Triton Device-Side APIs +############################################################################# + +@triton.jit +def put(dst_ptr, data, dst_rank: tl.constexpr, device_ctx, mask): + """ + RDMA put (write) operation from Triton kernel. + + Writes data to remote rank's memory via RDMA. + + Args: + dst_ptr: Destination pointer (remote address) - can be block of pointers + data: Data values to write (block) + dst_rank: Target rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Example: + >>> @triton.jit + >>> def kernel(dst_ptr, src_ptr, device_ctx, dst_rank, BLOCK_SIZE: tl.constexpr): + >>> pid = tl.program_id(0) + >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> + >>> data = tl.load(src_ptr + offsets, mask=mask) + >>> iris_rdma.put(dst_ptr + offsets, data, dst_rank, device_ctx, mask) + """ + # Extract heap bases from device context + # Context format: [rank, world_size, heap_base_0, heap_base_1, ...] + dst_heap_base = tl.load(device_ctx + 2 + dst_rank) + + # For now, use tl.store as placeholder + # TODO: Implement actual RDMA put via queue or direct posting + # This will require either: + # 1. A device-side queue that CPU polls (like iris-rdma prototype) + # 2. Or direct ibv_post_send from GPU (requires GPU Direct Async) + + # Translate pointer to remote address space + # dst_ptr should already be in the remote address space + # Just store for now - in full implementation, this would queue RDMA request + tl.store(dst_ptr, data, mask=mask) + + +@triton.jit +def get(src_ptr, from_rank: tl.constexpr, device_ctx, mask): + """ + RDMA get (read) operation from Triton kernel. + + Reads data from remote rank's memory via RDMA. + + Args: + src_ptr: Source pointer (remote address) - can be block of pointers + from_rank: Source rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Returns: + Block of data read from remote rank + + Example: + >>> @triton.jit + >>> def kernel(dst_ptr, src_ptr, device_ctx, from_rank, BLOCK_SIZE: tl.constexpr): + >>> pid = tl.program_id(0) + >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> + >>> data = iris_rdma.get(src_ptr + offsets, from_rank, device_ctx, mask) + >>> tl.store(dst_ptr + offsets, data, mask=mask) + """ + # Extract heap bases from device context + src_heap_base = tl.load(device_ctx + 2 + from_rank) + + # For now, use tl.load as placeholder + # TODO: Implement actual RDMA get via queue or direct posting + data = tl.load(src_ptr, mask=mask) + + return data + + +__all__ = [ + "IrisRDMA", + "iris", + "put", + "get", +] + From 92e7bad8faa62d0b955035b307dbf71b5f029ca2 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 12:06:34 -0500 Subject: [PATCH 04/16] Add RDMA example --- examples/22_rdma_producer_consumer/README.md | 69 ++++++++ .../rdma_producer_consumer.py | 162 ++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 examples/22_rdma_producer_consumer/README.md create mode 100755 examples/22_rdma_producer_consumer/rdma_producer_consumer.py diff --git a/examples/22_rdma_producer_consumer/README.md b/examples/22_rdma_producer_consumer/README.md new file mode 100644 index 00000000..1d9139d4 --- /dev/null +++ b/examples/22_rdma_producer_consumer/README.md @@ -0,0 +1,69 @@ +# 22. RDMA Producer-Consumer + +Producer-consumer pattern using InfiniBand RDMA for multi-node communication. + +## Overview + +This example demonstrates: +- Producer Triton kernel generates data on Rank 0 +- RDMA transfer from Rank 0 to Rank 1 +- Consumer Triton kernel verifies data on Rank 1 + +## Requirements + +- InfiniBand network adapter +- libibverbs-dev installed +- Iris built with RDMA support + +## Architecture + +``` +Rank 0 (Producer) Rank 1 (Consumer) +───────────────── ───────────────── +producer_kernel() + ↓ writes +GPU → CPU buffer + ↓ +RDMA PUT ──────────────────→ CPU buffer + ↓ + CPU → GPU + ↓ + consumer_kernel() + ↓ verifies + ✓ Success +``` + +## Usage + +### Single Node (2 GPUs) +```bash +torchrun --nproc_per_node=2 rdma_producer_consumer.py +``` + +### Multi-Node (2 Nodes, 1 GPU each) +```bash +# Node 0 +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=0 \ + --master_addr= --master_port=29500 \ + rdma_producer_consumer.py + +# Node 1 +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=1 \ + --master_addr= --master_port=29500 \ + rdma_producer_consumer.py +``` + +## Expected Output + +``` +[Rank 0/2] Initialized on cuda:0 +[Rank 1/2] Initialized on cuda:1 +[Rank 0] Producing data +[Rank 0] First 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 0] RDMA transfer to Rank 1 +[Rank 0] RDMA completed +[Rank 1] Consuming data +[Rank 1] Received first 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 1] Verified: 4096/4096 +[Rank 1] SUCCESS! +``` diff --git a/examples/22_rdma_producer_consumer/rdma_producer_consumer.py b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py new file mode 100755 index 00000000..c29f08c3 --- /dev/null +++ b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import os +import sys +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +import iris.experimental.iris_rdma as iris_rdma + + +@triton.jit +def producer_kernel( + output_ptr, + n_elements, + rank_id, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + data = (rank_id * 1000 + offsets).to(tl.float32) + tl.store(output_ptr + offsets, data, mask=mask) + + +@triton.jit +def consumer_kernel( + input_ptr, + result_ptr, + n_elements, + expected_rank_id, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + expected = (expected_rank_id * 1000 + offsets).to(tl.float32) + is_correct = (data == expected).to(tl.float32) + + tl.store(result_ptr + offsets, is_correct, mask=mask) + + +def main(): + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + device_id = torch.device(f"cuda:{local_rank}") + + dist.init_process_group( + backend='nccl', + device_id=device_id + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + + if world_size < 2: + print("This example requires at least 2 ranks") + sys.exit(1) + + torch.cuda.set_device(local_rank) + device = f'cuda:{local_rank}' + + print(f"[Rank {rank}/{world_size}] Initialized on {device}") + + heap_size = 1024 * 1024 * 8 + ctx = iris_rdma.iris(heap_size=heap_size) + + print(f"[Rank {rank}] Iris RDMA initialized") + + n_elements = 4096 + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + local_buffer = ctx.zeros(n_elements, dtype=torch.float32) + + ctx.barrier() + + if rank == 0: + print(f"\n[Rank 0] Producing data") + + gpu_buffer = local_buffer.to(device) + + producer_kernel[grid]( + gpu_buffer, + n_elements, + rank_id=0, + BLOCK_SIZE=BLOCK_SIZE, + ) + + local_buffer.copy_(gpu_buffer.cpu()) + + print(f"[Rank 0] First 10: {local_buffer[:10].tolist()}") + + dst_rank = 1 + local_addr = local_buffer.data_ptr() + remote_addr = ctx.remote_heap_bases[dst_rank] + size = n_elements * 4 + + print(f"[Rank 0] RDMA transfer to Rank {dst_rank}") + + ret = ctx.rdma_put(dst_rank, local_addr, remote_addr, size) + + if ret == 0: + import time + for attempt in range(100): + n_comp = ctx.poll_completion(dst_rank) + if n_comp > 0: + print(f"[Rank 0] RDMA completed") + break + time.sleep(0.001) + + ctx.barrier() + + if rank == 1: + print(f"\n[Rank 1] Consuming data") + + gpu_buffer = local_buffer.to(device) + + print(f"[Rank 1] Received first 10: {local_buffer[:10].tolist()}") + + result_buffer = torch.zeros(n_elements, dtype=torch.float32, device=device) + + consumer_kernel[grid]( + gpu_buffer, + result_buffer, + n_elements, + expected_rank_id=0, + BLOCK_SIZE=BLOCK_SIZE, + ) + + result_cpu = result_buffer.cpu() + num_correct = result_cpu.sum().item() + num_total = n_elements + + print(f"[Rank 1] Verified: {int(num_correct)}/{num_total}") + + if num_correct == num_total: + print(f"[Rank 1] SUCCESS!") + else: + print(f"[Rank 1] FAILED") + sys.exit(1) + + ctx.barrier() + + if rank == 0: + print(f"\n{'='*60}") + print(f"RDMA Producer-Consumer Complete") + print(f"{'='*60}") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + From fff5dd414906f9a88cfda1231b7aa0929800d542 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 12:06:44 -0500 Subject: [PATCH 05/16] Update docker --- docker/Dockerfile | 86 ++++++++++++++++++++++++++++------------------- docker/build.sh | 10 +++--- docker/run.sh | 48 ++++++++++++++++++++------ 3 files changed, 94 insertions(+), 50 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 8b49c01a..ba69c823 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,55 +1,71 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -FROM rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch +FROM rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0 # Use bash shell for RUN commands SHELL ["/bin/bash", "-c"] # Set environment variables -ENV TRITON_PATH=/opt/triton \ - ROCM_PATH=/opt/rocm \ - OMPI_MCA_mtl="^ofi" \ - OMPI_MCA_pml="ob1" +ENV ROCM_PATH=/opt/rocm ENV LD_LIBRARY_PATH=$ROCM_PATH/lib:$LD_LIBRARY_PATH \ PATH="$ROCM_PATH/bin:$PATH" -ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ - OMPI_ALLOW_RUN_AS_ROOT=1 - -# Install system packages +# Install system packages needed for Iris RDMA RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential && \ - rm -rf /var/lib/apt/lists/* + git wget cmake build-essential \ + libibverbs-dev librdmacm-dev \ + python3-pip python3-dev \ + infiniband-diags \ + perftest \ + && rm -rf /var/lib/apt/lists/* -# Install Python packages with pip +# Install Python packages RUN pip3 install --upgrade pip && \ - pip3 install wheel jupyter - -# Clone and install Triton -WORKDIR $TRITON_PATH -RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH -RUN git checkout dd5823453bcc7973eabadb65f9d827c43281c434 -RUN pip3 install -e . -ENV PYTHONPATH=$TRITON_PATH + pip3 install pybind11 -# Install rocprofiler-systems +# Set working directory WORKDIR /workspace -RUN wget https://github.com/ROCm/rocprofiler-systems/releases/download/rocm-6.3.1/rocprofiler-systems-install.py && \ - python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 6.3 && \ - rm -f rocprofiler-systems-install.py # Create entrypoint script -RUN echo '#!/bin/bash' > /entrypoint.sh && \ - echo 'echo "Welcome to the ROCm-aware Docker image!"' >> /entrypoint.sh && \ - echo 'if [ $# -eq 0 ]; then' >> /entrypoint.sh && \ - echo ' exec /bin/bash' >> /entrypoint.sh && \ - echo 'else' >> /entrypoint.sh && \ - echo ' exec "$@"' >> /entrypoint.sh && \ - echo 'fi' >> /entrypoint.sh && \ - chmod +x /entrypoint.sh +RUN printf '#!/bin/bash\n\ +echo "=== Iris RDMA Development Environment ==="\n\ +echo "ROCm version: $(cat $ROCM_PATH/.info/version 2>/dev/null || echo unknown)"\n\ +echo "PyTorch version: $(python -c '\''import torch; print(torch.__version__)'\'' 2>/dev/null)"\n\ +\n\ +# GPU detection using PyTorch\n\ +python -c '\''\n\ +import torch\n\ +if torch.cuda.is_available():\n\ + count = torch.cuda.device_count()\n\ + print(f"GPUs available: {count}")\n\ + for i in range(count):\n\ + name = torch.cuda.get_device_name(i)\n\ + print(f" GPU[{i}]: {name}")\n\ +else:\n\ + print("GPUs available: 0")\n\ +'\'' 2>/dev/null || echo "GPUs available: 0"\n\ +\n\ +# InfiniBand detection\n\ +if [ -d /dev/infiniband ]; then\n\ + IB_COUNT=$(ls /dev/infiniband/uverbs* 2>/dev/null | wc -l)\n\ + echo "InfiniBand devices available: $IB_COUNT"\n\ + if [ $IB_COUNT -gt 0 ]; then\n\ + echo "InfiniBand device(s): $(ls /sys/class/infiniband/ 2>/dev/null | tr '\''\n'\'' '\'' '\'')"\n\ + fi\n\ +else\n\ + echo "InfiniBand devices available: 0"\n\ +fi\n\ +echo "======================================"\n\ +if [ $# -eq 0 ]; then\n\ + exec /bin/bash\n\ +else\n\ + exec "$@"\n\ +fi\n' > /entrypoint.sh + +RUN chmod +x /entrypoint.sh # Set the entrypoint -ENTRYPOINT ["/bin/bash", "-c", "source /entrypoint.sh && exec bash"] \ No newline at end of file +ENTRYPOINT ["/entrypoint.sh"] +CMD ["/bin/bash"] + diff --git a/docker/build.sh b/docker/build.sh index 973c9366..d86bf5a7 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -1,13 +1,13 @@ #!/bin/bash -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Build miniQP Docker image SCRIPT_DIR=$(dirname "$(realpath "$0")") - -IMAGE_NAME=${1:-"iris-dev"} +IMAGE_NAME=${1:-"iris-rdma"} pushd "$SCRIPT_DIR" > /dev/null -docker build -t $IMAGE_NAME . +echo "Building Docker image: $IMAGE_NAME" +docker build -t $IMAGE_NAME --network=host . popd > /dev/null + diff --git a/docker/run.sh b/docker/run.sh index c967875b..0864bab4 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -1,14 +1,42 @@ #!/bin/bash -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Run Iris RDMA Docker container with InfiniBand support +IMAGE_NAME=${1:-"iris-rdma"} +WORKSPACE_DIR=$(cd "$(dirname "$0")/.." && pwd) -IMAGE_NAME=${1:-"iris-dev"} -WORKSPACE_DIR=${2:-"$(pwd)"} +echo "Starting miniQP container..." +echo " Image: $IMAGE_NAME" +echo " Workspace: $WORKSPACE_DIR" + +# Auto-detect InfiniBand devices +IB_DEVICES="" +if [ -d /dev/infiniband ]; then + for dev in /dev/infiniband/uverbs*; do + if [ -e "$dev" ]; then + IB_DEVICES="$IB_DEVICES --device=$dev" + fi + done + if [ -n "$IB_DEVICES" ]; then + echo " InfiniBand devices: $(ls /dev/infiniband/uverbs* 2>/dev/null | wc -l) found" + fi +else + echo " Warning: No InfiniBand devices found" +fi +echo "" + +docker run -it --rm \ + --network=host \ + --device=/dev/kfd \ + --device=/dev/dri \ + $IB_DEVICES \ + --group-add video \ + --cap-add=SYS_PTRACE \ + --cap-add=IPC_LOCK \ + --security-opt seccomp=unconfined \ + -v "$WORKSPACE_DIR:$WORKSPACE_DIR" \ + -w "$WORKSPACE_DIR" \ + --shm-size=16G \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + $IMAGE_NAME -docker run -it --network=host --device=/dev/kfd\ - --device=/dev/dri --group-add video\ - --cap-add=SYS_PTRACE --security-opt seccomp=unconfined\ - -v "$WORKSPACE_DIR:$WORKSPACE_DIR" -w "$WORKSPACE_DIR"\ - --shm-size=16G --ulimit memlock=-1\ - --ulimit stack=67108864 $IMAGE_NAME From 8c872bea95a6e932e1213ed53478950627a6bc20 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 17:37:45 -0500 Subject: [PATCH 06/16] Add iris manager --- .../iris_rdma/python/bindings.cpp | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/iris/experimental/iris_rdma/python/bindings.cpp b/iris/experimental/iris_rdma/python/bindings.cpp index a50ea69e..6790741d 100644 --- a/iris/experimental/iris_rdma/python/bindings.cpp +++ b/iris/experimental/iris_rdma/python/bindings.cpp @@ -13,6 +13,7 @@ #include "network_backend.hpp" #include "queue_pair.hpp" #include "torch_bootstrap.hpp" +#include "iris_manager.hpp" namespace py = pybind11; using namespace iris_rdma; @@ -148,5 +149,69 @@ PYBIND11_MODULE(_iris_rdma_backend, m) { return ""; }); + + py::class_(m, "IrisManager") + .def(py::init([](std::shared_ptr bootstrap, py::object heap_tensor, int queue_size) { + // Extract heap pointer from tensor + if (!THPVariable_Check(heap_tensor.ptr())) { + throw std::runtime_error("heap_tensor must be a PyTorch tensor"); + } + auto heap = THPVariable_Unpack(heap_tensor.ptr()); + void* heap_ptr = heap.data_ptr(); + size_t heap_size = heap.numel() * heap.element_size(); + + return new iris::IrisManager(bootstrap, heap_ptr, heap_size, queue_size); + }), + py::arg("bootstrap"), py::arg("heap_tensor"), py::arg("queue_size") = 512, + "Create IrisManager with NetworkBackend + Queue + Proxy Thread") + .def("start_proxy_thread", &iris::IrisManager::startProxyThread, + "Start proxy thread that processes RDMA operations from queue") + .def("stop_proxy_thread", &iris::IrisManager::stopProxyThread, + "Stop proxy thread") + .def("get_queue_ptr", + [](iris::IrisManager& self) { + return reinterpret_cast(self.getQueuePtr()); + }, + "Get queue pointer for Triton kernels") + .def("get_heap_base", &iris::IrisManager::getHeapBase, + "Get local heap base address") + .def("get_remote_heap_base", &iris::IrisManager::getRemoteHeapBase, + py::arg("rank"), + "Get remote heap base address for a rank") + .def("get_rank", &iris::IrisManager::getRank, "Get rank") + .def("get_world_size", &iris::IrisManager::getWorldSize, "Get world size") + .def("is_queue_empty", &iris::IrisManager::isQueueEmpty, + "Check if queue is empty (all work items processed)") + .def("rdma_write", + [](iris::IrisManager& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + auto backend = self.getBackend(); + return backend->rdmaWrite(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA write to remote rank (local_addr is integer address)") + .def("rdma_read", + [](iris::IrisManager& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + auto backend = self.getBackend(); + return backend->rdmaRead(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA read from remote rank (local_addr is integer address)") + .def("poll_cq", + [](iris::IrisManager& self, int dst_rank, int max_completions) { + auto backend = self.getBackend(); + return backend->pollCQ(dst_rank, max_completions); + }, + py::arg("dst_rank"), py::arg("max_completions") = 1, + "Poll completion queue for RDMA operations") + .def("__repr__", [](const iris::IrisManager& mgr) { + return ""; + }); } From 5abb6f6253805aaee0ff719046982c3dedfab06c Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 17:38:03 -0500 Subject: [PATCH 07/16] Add iris manager --- .../iris_rdma/src/iris_manager.hpp | 246 ++++++++++++++++++ 1 file changed, 246 insertions(+) create mode 100644 iris/experimental/iris_rdma/src/iris_manager.hpp diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp new file mode 100644 index 00000000..d4559cdd --- /dev/null +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file iris_manager.hpp + * @brief Complete Iris RDMA integration: Network + Queue + Proxy Thread + * + * Combines: + * - NetworkBackend (InfiniBand RDMA) + * - TritonDeviceQueue (GPU->CPU queue) + * - Proxy Thread (processes RDMA operations from queue) + */ + +#pragma once + +#include +#include +#include "network_backend.hpp" +#include "queue.hpp" + +namespace iris { + +/** + * @brief Complete Iris RDMA Manager + * + * Integration of NetworkBackend + TritonDeviceQueue + Proxy Thread + * Provides a unified interface for Triton kernels to perform RDMA operations + */ +class IrisManager { + public: + /** + * @brief Constructor + * @param bootstrap PyTorch bootstrap for distributed communication + * @param heap_base Pointer to symmetric heap + * @param heap_size Size of symmetric heap in bytes + * @param queue_size Queue capacity (default: 512) + */ + IrisManager(std::shared_ptr bootstrap, + void* heap_base, + size_t heap_size, + int queue_size = 512) + : heap_base_((uint64_t)heap_base), + heap_size_(heap_size), + running_(false) { + + // Step 1: Create NetworkBackend and initialize + backend_ = std::make_unique(bootstrap); + backend_->init(); + + // Step 2: Register symmetric heap (collective operation) + backend_->registerMemory(heap_base, heap_size); + + // Step 3: Create CPU-GPU queue + queue_ = std::make_unique(queue_size); + } + + ~IrisManager() { + if (running_) { + stopProxyThread(); + } + } + + /** + * @brief Start the proxy thread that processes RDMA operations + */ + void startProxyThread() { + if (running_) return; + running_ = true; + proxy_thread_ = std::thread(&IrisManager::proxyLoop, this); + } + + /** + * @brief Stop the proxy thread + */ + void stopProxyThread() { + running_ = false; + if (proxy_thread_.joinable()) { + proxy_thread_.join(); + } + } + + /** + * @brief Get the queue state pointer (for passing to Triton kernels) + */ + gpu_cpu_queue::QueueState* getQueuePtr() { + return queue_->getQueuePtr(); + } + + /** + * @brief Get heap base address + */ + uint64_t getHeapBase() { return heap_base_; } + + /** + * @brief Get the NetworkBackend (for direct RDMA operations) + */ + iris_rdma::NetworkBackend* getBackend() { return backend_.get(); } + + /** + * @brief Get remote heap base for a given rank + */ + uint64_t getRemoteHeapBase(int rank) { + return backend_->getRemoteHeapBase(rank); + } + + /** + * @brief Get rank + */ + int getRank() const { return backend_->getRank(); } + + /** + * @brief Get world size + */ + int getWorldSize() const { return backend_->getWorldSize(); } + + /** + * @brief Check if queue is empty (all work processed) + */ + bool isQueueEmpty() const { return queue_->isEmpty(); } + + private: + /** + * @brief Main proxy loop - processes RDMA operations from GPU queue + */ + void proxyLoop() { + gpu_cpu_queue::WorkItem item; + int checkCounter = 1000; + + while (true) { + // Check if should stop + if (checkCounter-- == 0) { + checkCounter = 1000; + if (!running_) break; + } + + // Poll for work from GPU queue + if (queue_->poll(item)) { + processWorkItem(item); + } + } + } + + /** + * @brief Process a single work item from the queue + */ + void processWorkItem(const gpu_cpu_queue::WorkItem& item) { + auto op_type = static_cast(item.header.op_type); + int dst_rank = item.header.rank; + + // Get addresses from queue metadata + uint64_t src_ptr = item.header.src_ptr; // Pointer/offset in registered heap + uint64_t dst_ptr = item.header.dst_ptr; // Remote destination + size_t size = item.header.size_bytes; + + switch (op_type) { + case gpu_cpu_queue::OperationType::PUT: { + // RDMA Write: Data is already in the registered heap at src_ptr + // No memcpy needed - just RDMA directly from heap! + void* local_addr = (void*)src_ptr; + + DEBUG_PRINT("[IrisManager] PUT: rank=%d src=%lx dst=%lx size=%zu", + dst_rank, src_ptr, dst_ptr, size); + + int ret = backend_->rdmaWrite(dst_rank, local_addr, dst_ptr, size); + if (ret != 0) { + fprintf(stderr, "[IrisManager] RDMA write failed: dst=%d size=%lu\n", dst_rank, size); + } else { + // Poll for completion + int n = 0; + for (int attempt = 0; attempt < 100; attempt++) { + n = backend_->pollCQ(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + DEBUG_PRINT("[IrisManager] Warning: PUT completion not polled (may be OK if async)"); + } + } + + // Signal completion + queue_->pop(); + break; + } + + case gpu_cpu_queue::OperationType::GET: { + // RDMA Read: Read from remote directly into registered heap at src_ptr + // GPU will read from heap after completion + void* local_addr = (void*)src_ptr; + + DEBUG_PRINT("[IrisManager] GET: rank=%d src=%lx dst=%lx size=%zu", + dst_rank, dst_ptr, src_ptr, size); + + int ret = backend_->rdmaRead(dst_rank, local_addr, dst_ptr, size); + if (ret != 0) { + fprintf(stderr, "[IrisManager] RDMA read failed: dst=%d size=%lu\n", dst_rank, size); + } else { + // Poll for completion + int n = 0; + for (int attempt = 0; attempt < 100; attempt++) { + n = backend_->pollCQ(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + DEBUG_PRINT("[IrisManager] Warning: GET completion not polled (may be OK if async)"); + } + } + + // Signal completion - GPU can now read from heap at src_ptr + queue_->pop(); + break; + } + + case gpu_cpu_queue::OperationType::FLUSH: { + // Flush all pending operations for this rank + DEBUG_PRINT("[IrisManager] FLUSH: rank=%d", dst_rank); + + int total = 0; + int n; + do { + n = backend_->pollCQ(dst_rank, 16); + if (n > 0) total += n; + } while (n > 0); + + queue_->pop(); + break; + } + + default: + fprintf(stderr, "[IrisManager] Unknown operation type: %d\n", item.header.op_type); + queue_->pop(); + } + } + + std::unique_ptr backend_; + std::unique_ptr queue_; + + uint64_t heap_base_; + size_t heap_size_; + + std::atomic running_; + std::thread proxy_thread_; +}; + +} // namespace iris + From 42b50fda904f659d576e1bc40ba3edf3ad07500f Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 17:38:21 -0500 Subject: [PATCH 08/16] Device-side enqueue --- iris/experimental/iris_rdma.py | 300 ++++++++++++++++++++++++++------- 1 file changed, 240 insertions(+), 60 deletions(-) diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py index c686e544..b04ce019 100644 --- a/iris/experimental/iris_rdma.py +++ b/iris/experimental/iris_rdma.py @@ -74,7 +74,7 @@ class IrisRDMA: >>> buffer = ctx.zeros(1024, dtype=torch.float32) """ - def __init__(self, heap_size=1 << 30, process_group=None, device_name=None): + def __init__(self, heap_size=1 << 30, process_group=None, queue_size=512): # Check if distributed is initialized if not dist.is_initialized(): raise RuntimeError( @@ -96,32 +96,41 @@ def __init__(self, heap_size=1 << 30, process_group=None, device_name=None): # Create TorchBootstrap self._bootstrap = backend.TorchBootstrap(process_group) - # Create NetworkBackend - self._backend = backend.NetworkBackend(self._bootstrap, device_name) - - # Initialize network (create QPs, transition to RTS) - self._backend.init() - # Allocate symmetric heap (CPU pinned memory for now) # TODO: Support GPU memory with GPUDirect RDMA self.heap_size = heap_size self.heap_offset = 0 self.alignment = 1024 - # Create CPU pinned memory pool - # For GPU memory, use: torch.empty(heap_size, device=self.device, dtype=torch.int8) - self.memory_pool = torch.empty(heap_size, device='cpu', dtype=torch.int8).pin_memory() + # Create GPU memory pool + self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) + + self._manager = backend.IrisManager(self._bootstrap, self.memory_pool, queue_size) + self._manager.start_proxy_thread() + + self._backend = self._manager - # Register memory with RDMA - self._backend.register_memory(self.memory_pool) + logger.info(f"[Rank {self.rank}] Using IrisManager with queue (size={queue_size})") - # Store remote heap bases (already exchanged in register_memory) self.remote_heap_bases = [] for i in range(self.world_size): - self.remote_heap_bases.append(self._backend.get_remote_heap_base(i)) + self.remote_heap_bases.append(self._manager.get_remote_heap_base(i)) logger.info(f"[Rank {self.rank}] Iris RDMA initialized: heap_size={heap_size}, " - f"heap_base={self._backend.get_heap_base():#x}") + f"heap_base={self._manager.get_heap_base():#x}") + + def __del__(self): + """Clean up resources""" + if hasattr(self, '_manager') and self._manager is not None: + self._manager.stop_proxy_thread() + + def get_heap_base(self): + """Get local heap base address""" + return self._manager.get_heap_base() + + def get_queue_ptr(self): + """Get queue pointer for Triton kernels""" + return self._manager.get_queue_ptr() def get_device_context(self): """ @@ -130,7 +139,8 @@ def get_device_context(self): The context tensor encodes: - [0]: current rank - [1]: world size - - [2:]: heap base addresses for all ranks + - [2]: queue pointer (for enqueueing RDMA operations) + - [3:]: heap base addresses for all ranks Returns: torch.Tensor: Device context tensor (on GPU) @@ -140,15 +150,16 @@ def get_device_context(self): >>> device_ctx = ctx.get_device_context() >>> # Pass device_ctx to Triton kernel """ - # Create context tensor: [rank, world_size, heap_base_0, heap_base_1, ...] - context_size = 2 + self.world_size + # Create context tensor: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] + context_size = 3 + self.world_size context = torch.zeros(context_size, dtype=torch.int64, device=self.device) context[0] = self.rank context[1] = self.world_size + context[2] = self.get_queue_ptr() for i in range(self.world_size): - context[2 + i] = self.remote_heap_bases[i] + context[3 + i] = self.remote_heap_bases[i] return context @@ -159,16 +170,16 @@ def zeros(self, *size, dtype=torch.float32, device=None): Args: *size: Tensor dimensions dtype: Data type (default: torch.float32) - device: Device placement ('cpu' or 'cuda', default: match context) + device: Device placement (default: GPU for direct kernel access) Returns: - torch.Tensor: Allocated tensor + torch.Tensor: Allocated tensor (on GPU by default) Example: >>> buffer = ctx.zeros(1024, 1024, dtype=torch.float32) """ if device is None: - device = 'cpu' # Use CPU for now (pinned memory) + device = self.device # Use GPU by default (for GPUDirect) # Calculate size in bytes elem_size = torch.tensor([], dtype=dtype).element_size() @@ -203,13 +214,47 @@ def zeros(self, *size, dtype=torch.float32, device=None): def barrier(self): """ - Synchronize all ranks. + Synchronize all ranks and drain RDMA queue. + + Waits for: + 1. All enqueued RDMA operations to complete (queue drains) + 2. All ranks to reach this barrier Example: - >>> ctx.barrier() # Wait for all ranks + >>> ctx.barrier() # Wait for all ranks and RDMA completion """ + # First, wait for queue to drain (all work processed) + self.wait_queue_drain() + + # Then synchronize with other ranks dist.barrier() + def wait_queue_drain(self, timeout=30.0): + """ + Wait for the CPU proxy thread to process all enqueued work items. + + Spins until queue is empty (head == tail), meaning all work has been + processed and popped by the CPU proxy thread. + + Args: + timeout: Maximum time to wait in seconds + + Raises: + TimeoutError: If queue doesn't drain within timeout + """ + import time + start = time.time() + + while time.time() - start < timeout: + # Check if queue is empty (head == tail) + if self._manager.is_queue_empty(): + return + + # Small sleep to avoid burning CPU + time.sleep(0.0001) # 100 microseconds + + raise TimeoutError(f"Queue did not drain within {timeout}s") + def rdma_put(self, dst_rank, local_addr, remote_addr, size): """ Perform RDMA write (put) to remote rank. @@ -278,14 +323,14 @@ def __repr__(self): return f"" -def iris(heap_size=1 << 30, process_group=None, device_name=None): +def iris(heap_size=1 << 30, process_group=None, queue_size=512): """ Factory function to create Iris RDMA context. Args: heap_size (int): Size of the symmetric heap in bytes process_group: PyTorch distributed process group - device_name (str): InfiniBand device name (optional) + queue_size (int): Queue size for GPU->CPU RDMA operations Returns: IrisRDMA: RDMA context object @@ -294,7 +339,7 @@ def iris(heap_size=1 << 30, process_group=None, device_name=None): >>> import iris.experimental.iris_rdma as iris_rdma >>> ctx = iris_rdma.iris(heap_size=2**30) """ - return IrisRDMA(heap_size, process_group, device_name) + return IrisRDMA(heap_size, process_group, queue_size) ############################################################################# @@ -302,15 +347,146 @@ def iris(heap_size=1 << 30, process_group=None, device_name=None): ############################################################################# @triton.jit -def put(dst_ptr, data, dst_rank: tl.constexpr, device_ctx, mask): +def _wait_for_completion(queue_ptr, queue_pos): + """ + Wait for CPU to process a queue item. + + Spins until tail pointer advances past our queue position, + indicating the CPU has processed and popped our item. + + Args: + queue_ptr: Queue context pointer + queue_pos: Queue position to wait for (returned from _enqueue_rdma_op) + """ + state_ptr = queue_ptr.to(tl.pointer_type(tl.uint64)) + + # Load tail pointer (offset 2 in QueueState) + # Use volatile and cache modifier to prevent caching + tail_ptr = tl.load(state_ptr + 2, cache_modifier=".cv", volatile=True) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + # Spin until CPU advances tail past our position + while queue_pos >= current_tail: + tail_ptr = tl.load(state_ptr + 2, cache_modifier=".cv", volatile=True) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + +@triton.jit +def _enqueue_rdma_op(dst_ptr, src_ptr, to_rank: tl.constexpr, op_code: tl.constexpr, queue_ptr, mask): + """ + Internal: Enqueue an RDMA operation to the queue. + + Args: + dst_ptr: Destination pointer on remote rank + src_ptr: Source pointer (local address where data is stored in registered heap) + to_rank: Target rank ID + op_code: Operation type (1=PUT, 2=GET) + queue_ptr: Queue pointer from device context + mask: Triton mask for valid elements + """ + # Queue structure (from queue.hpp): + # struct QueueState { + # WorkItem* items; // offset 0 + # uint64_t* head; // offset 8 + # uint64_t* tail; // offset 16 + # uint64_t* tailCache; // offset 24 + # int32_t size; // offset 32 + # }; + + state_ptr = queue_ptr.to(tl.pointer_type(tl.uint64)) + + # Load QueueState fields + items_ptr = tl.load(state_ptr + 0) + head_ptr = tl.load(state_ptr + 1) + tail_ptr = tl.load(state_ptr + 2) + + # Load size (at offset 32 bytes = 4 * uint64) + size_ptr = queue_ptr.to(tl.pointer_type(tl.int32)) + size = tl.load(size_ptr + 8) + + # Atomic increment head to reserve slot + head_ptr_typed = head_ptr.to(tl.pointer_type(tl.uint64)) + prev_head = tl.atomic_add(head_ptr_typed, 1, sem='relaxed', scope='sys') + + # Wait for slot to be free: spin if prev_head >= size + *tail + size_u64 = size.to(tl.uint64) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + while prev_head >= size_u64 + current_tail: + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + # Calculate slot position + slot_idx = prev_head % size_u64 + + # WorkItem structure (32 bytes): + # struct WorkItem { + # uint64_t dst_ptr; // offset 0 + # uint64_t src_ptr; // offset 8 + # uint32_t size_bytes; // offset 16 - WRITE LAST as ready flag + # uint16_t rank; // offset 20 + # uint8_t op_type; // offset 22 + # uint8_t reserved; // offset 23 + # }; + WORK_ITEM_SIZE_BYTES = 32 + + slot_offset_bytes = slot_idx * WORK_ITEM_SIZE_BYTES + + # Get pointer to this work item + items_ptr_u64 = items_ptr.to(tl.pointer_type(tl.uint64)) + slot_ptr_u64 = items_ptr_u64 + (slot_offset_bytes // 8).to(tl.int32) + + # Extract destination address (min of pointer block) + dst_ptr_u64 = dst_ptr.to(tl.uint64) + dst_ptr_val = tl.min(dst_ptr_u64, axis=0) + + # Extract source address (min of pointer block where data is stored) + src_ptr_u64 = src_ptr.to(tl.uint64) + src_ptr_val = tl.min(src_ptr_u64, axis=0) + + # Calculate size in bytes from pointer range + # max_ptr - min_ptr gives us the byte distance to the last element + # Add element_size to include the last element itself + max_src_ptr = tl.max(src_ptr_u64, axis=0) + element_size_bytes = 4 # float32 + num_bytes = (max_src_ptr - src_ptr_val + element_size_bytes).to(tl.uint32) + size_bytes = num_bytes + + # Write header fields (but NOT size_bytes yet - it's the ready flag) + # Write dst_ptr (offset 0) + tl.store(slot_ptr_u64 + 0, dst_ptr_val) + + # Write src_ptr (offset 8) + tl.store(slot_ptr_u64 + 1, src_ptr_val) + + # Write rank + op_type (offset 20-23) + metadata = (to_rank & 0xFFFF) | ((op_code & 0xFF) << 16) + slot_ptr_u32 = slot_ptr_u64.to(tl.pointer_type(tl.uint32)) + tl.store(slot_ptr_u32 + 5, metadata.to(tl.uint32)) + + # Write size_bytes LAST as ready flag (offset 16) + size_bytes_ptr = (slot_ptr_u32 + 4).to(tl.pointer_type(tl.uint32)) + tl.atomic_xchg(size_bytes_ptr, size_bytes, sem='release', scope='sys') + + # Return queue position for waiting + return prev_head + + +@triton.jit +def put(dst_ptr, src_ptr, data, dst_rank: tl.constexpr, device_ctx, mask): """ RDMA put (write) operation from Triton kernel. - Writes data to remote rank's memory via RDMA. + Enqueues data to be written to remote rank via RDMA. + Data must first be stored in the registered heap at src_ptr location. + The CPU proxy thread will dequeue and perform the actual RDMA write. Args: dst_ptr: Destination pointer (remote address) - can be block of pointers - data: Data values to write (block) + src_ptr: Source pointer (local address in registered heap) - can be block of pointers + data: Data values to write (block) - will be stored at src_ptr dst_rank: Target rank ID (must be compile-time constant) device_ctx: Device context from iris_rdma.get_device_context() mask: Triton mask for valid elements @@ -322,59 +498,63 @@ def put(dst_ptr, data, dst_rank: tl.constexpr, device_ctx, mask): >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> - >>> data = tl.load(src_ptr + offsets, mask=mask) - >>> iris_rdma.put(dst_ptr + offsets, data, dst_rank, device_ctx, mask) + >>> data = generate_data(offsets) + >>> # Store data locally first (in registered heap) + >>> tl.store(src_ptr + offsets, data, mask=mask) + >>> # Enqueue RDMA operation + >>> iris_rdma.put(dst_ptr + offsets, src_ptr + offsets, data, dst_rank, device_ctx, mask) """ - # Extract heap bases from device context - # Context format: [rank, world_size, heap_base_0, heap_base_1, ...] - dst_heap_base = tl.load(device_ctx + 2 + dst_rank) - - # For now, use tl.store as placeholder - # TODO: Implement actual RDMA put via queue or direct posting - # This will require either: - # 1. A device-side queue that CPU polls (like iris-rdma prototype) - # 2. Or direct ibv_post_send from GPU (requires GPU Direct Async) - - # Translate pointer to remote address space - # dst_ptr should already be in the remote address space - # Just store for now - in full implementation, this would queue RDMA request - tl.store(dst_ptr, data, mask=mask) + # Extract queue pointer from device context + # Context format: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] + queue_ptr = tl.load(device_ctx + 2) + + # Store data in registered heap first + tl.store(src_ptr, data, mask=mask) + + # Enqueue PUT operation (op_code=1) + _enqueue_rdma_op(dst_ptr, src_ptr, dst_rank, 1, queue_ptr, mask) @triton.jit -def get(src_ptr, from_rank: tl.constexpr, device_ctx, mask): +def get(dst_ptr, src_ptr, from_rank: tl.constexpr, device_ctx, mask): """ RDMA get (read) operation from Triton kernel. - Reads data from remote rank's memory via RDMA. + Enqueues a request to read data from remote rank via RDMA and WAITS for completion. + The CPU proxy thread will dequeue, perform the RDMA read, then pop the item. + This function spins until the tail pointer advances, then data is ready at dst_ptr. Args: + dst_ptr: Local destination pointer where data will be written - can be block of pointers src_ptr: Source pointer (remote address) - can be block of pointers from_rank: Source rank ID (must be compile-time constant) device_ctx: Device context from iris_rdma.get_device_context() mask: Triton mask for valid elements - Returns: - Block of data read from remote rank - Example: >>> @triton.jit - >>> def kernel(dst_ptr, src_ptr, device_ctx, from_rank, BLOCK_SIZE: tl.constexpr): + >>> def kernel(local_ptr, remote_ptr, device_ctx, from_rank, BLOCK_SIZE: tl.constexpr): >>> pid = tl.program_id(0) >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> - >>> data = iris_rdma.get(src_ptr + offsets, from_rank, device_ctx, mask) - >>> tl.store(dst_ptr + offsets, data, mask=mask) + >>> # RDMA read from remote rank - blocks until complete + >>> iris_rdma.get(local_ptr + offsets, remote_ptr + offsets, from_rank, device_ctx, mask) + >>> + >>> # Data is now ready at local_ptr, can use it immediately + >>> data = tl.load(local_ptr + offsets, mask=mask) """ - # Extract heap bases from device context - src_heap_base = tl.load(device_ctx + 2 + from_rank) + # Extract queue pointer from device context + queue_ptr = tl.load(device_ctx + 2) + + # Enqueue GET operation (op_code=2) + # For GET: src_ptr is remote source, dst_ptr is local destination + queue_pos = _enqueue_rdma_op(src_ptr, dst_ptr, from_rank, 2, queue_ptr, mask) - # For now, use tl.load as placeholder - # TODO: Implement actual RDMA get via queue or direct posting - data = tl.load(src_ptr, mask=mask) + # Wait for CPU to complete the RDMA read + _wait_for_completion(queue_ptr, queue_pos) - return data + # Data is now ready at dst_ptr (CPU has written it there via RDMA) __all__ = [ From abbf8ec372117c782c8bec16e8212f4382f484f8 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 17:38:31 -0500 Subject: [PATCH 09/16] Add host-side queue --- iris/experimental/iris_rdma/src/queue.hpp | 232 ++++++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 iris/experimental/iris_rdma/src/queue.hpp diff --git a/iris/experimental/iris_rdma/src/queue.hpp b/iris/experimental/iris_rdma/src/queue.hpp new file mode 100644 index 00000000..dd2a5bb6 --- /dev/null +++ b/iris/experimental/iris_rdma/src/queue.hpp @@ -0,0 +1,232 @@ +// GPU-to-CPU Queue - C++ Host Side +// Exposes queue pointer to Python/Triton + +#ifndef QUEUE_HPP_ +#define QUEUE_HPP_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gpu_cpu_queue { + +// Operation types - simplified for Iris +enum class OperationType : uint8_t { + NOP = 0, + PUT = 1, // RDMA write + GET = 2, // RDMA read + FLUSH = 3, // Flush connection +}; + +// Work item structure - metadata only, no data storage +// Data is stored in the registered symmetric heap +struct alignas(16) WorkItemHeader { + uint64_t dst_ptr; // Destination pointer (where to write on remote) + uint64_t src_ptr; // Source pointer (offset in local registered heap) + uint32_t size_bytes; // Size in bytes to transfer (WRITE LAST as ready flag) + uint16_t rank; // Remote rank + uint8_t op_type; // Operation type (see OperationType enum) + uint8_t reserved; // Reserved for future use +}; + +// WorkItem is now just the header - no data array needed! +// All data lives in the registered symmetric heap +// Note: Completion is signaled by tail pointer advancement, not a flag +struct alignas(16) WorkItem { + WorkItemHeader header; +}; + +// Queue state visible to both CPU and GPU +struct QueueState { + WorkItem* items; // Queue buffer (pinned host memory) + uint64_t* head; // Head pointer (device memory, GPU writes) + uint64_t* tail; // Tail pointer (host memory, CPU writes, GPU reads) + uint64_t* tailCache; // Cached tail (device memory) + int32_t size; // Queue capacity +}; + +// CPU-side queue management +class Queue { + public: + explicit Queue(int size = 512) : size_(size), running_(false) { + // Allocate pinned memory for QueueState struct (GPU needs to read this) + hipHostMalloc(&state_, sizeof(QueueState)); + + // Allocate pinned memory for queue items + hipHostMalloc(&state_->items, size * sizeof(WorkItem)); + memset(state_->items, 0, size * sizeof(WorkItem)); + + // Allocate device memory for head + hipMalloc(&state_->head, sizeof(uint64_t)); + hipMemset(state_->head, 0, sizeof(uint64_t)); + + // Allocate pinned memory for tail (CPU writes, GPU reads) + hipHostMalloc(&state_->tail, sizeof(uint64_t)); + *state_->tail = 0; + + // Allocate device memory for tail cache + hipMalloc(&state_->tailCache, sizeof(uint64_t)); + hipMemset(state_->tailCache, 0, sizeof(uint64_t)); + + state_->size = size; + } + + ~Queue() { + if (running_) { + stopProxy(); + } + hipHostFree(state_->items); + hipFree(state_->head); + hipHostFree(state_->tail); + hipFree(state_->tailCache); + hipHostFree(state_); + } + + // Get raw pointer to queue state for Triton + QueueState* getQueuePtr() { return state_; } + + // Poll for new work item (non-blocking) + bool poll(WorkItem& item) { + uint64_t currentTail = *state_->tail; + WorkItem* ptr = &state_->items[currentTail % size_]; + + // Atomic load of size_bytes (acquire semantics) - use as ready flag + // size_bytes == 0 means slot is empty/processed + uint32_t size_bytes = + reinterpret_cast*>(&ptr->header.size_bytes)->load(std::memory_order_acquire); + + // Check if slot is ready + if (size_bytes == 0) { + return false; // Queue empty + } + + // Copy entire work item (just header now, no data array) + memcpy(&item, ptr, sizeof(WorkItem)); + + return true; + } + + // Mark work item as processed + void pop() { + uint64_t currentTail = *state_->tail; + + // Clear the size_bytes to mark as processed + state_->items[currentTail % size_].header.size_bytes = 0; + + // Advance tail with release semantics (GPU will reload this into tailCache) + uint64_t newTail = currentTail + 1; + reinterpret_cast*>(state_->tail)->store(newTail, std::memory_order_release); + } + + // Start proxy thread + void startProxy() { + if (running_) return; + + running_ = true; + proxyThread_ = std::thread([this]() { this->proxyLoop(); }); + } + + void stopProxy() { + running_ = false; + if (proxyThread_.joinable()) { + proxyThread_.join(); + } + } + + // Get queue statistics + uint64_t getTail() const { return *state_->tail; } + + uint64_t getHead() const { + uint64_t h; + hipMemcpy(&h, state_->head, sizeof(uint64_t), hipMemcpyDeviceToHost); + return h; + } + + int getSize() const { return size_; } + + // Check if queue is empty (all work processed) + bool isEmpty() const { + uint64_t h; + hipMemcpy(&h, state_->head, sizeof(uint64_t), hipMemcpyDeviceToHost); + return h == *state_->tail; + } + + private: + void proxyLoop() { + WorkItem item; + int checkCounter = 1000; + + while (true) { + // Check if should stop + if (checkCounter-- == 0) { + checkCounter = 1000; + if (!running_) break; + } + + // Poll for work + if (poll(item)) { + // Process the work item: print to stdout (later: send to NIC) + OperationType op = static_cast(item.header.op_type); + + // Get operation name + const char* opName = "UNKNOWN"; + switch (op) { + case OperationType::NOP: + opName = "NOP"; + break; + case OperationType::PUT: + opName = "PUT"; + break; + case OperationType::GET: + opName = "GET"; + break; + case OperationType::FLUSH: + opName = "FLUSH"; + break; + } + + // Silent processing (uncomment to debug) + // std::cout << "[CPU Proxy] Op=" << opName << " (0x" << std::hex << (int)item.header.op_type << std::dec << ")" + // << " rank=" << item.header.rank << " dst=0x" << std::hex << item.header.dst_ptr << std::dec + // << " size=" << item.header.block_size << " first_values=["; + // for (int i = 0; i < std::min(5, (int)item.header.block_size); i++) { + // std::cout << item.data[i]; + // if (i < std::min(5, (int)item.header.block_size) - 1) std::cout << ", "; + // } + // std::cout << "...]" << std::endl; + + // Process based on operation type + if (op == OperationType::PUT) { + // TODO: Replace with actual NIC write + // nic->write(item.header.dst_ptr, item.data, item.header.block_size, item.header.rank); + pop(); + + } else if (op == OperationType::GET) { + std::cout << "[CPU Proxy] Processing GET operation - rank=" << item.header.rank + << " size=" << item.header.size_bytes << std::endl; + + // This is a test/debug proxy - IrisManager has the real proxy with RDMA + // Just mark as complete for testing + pop(); + std::cout << "[CPU Proxy] GET operation complete" << std::endl; + } + } + } + } + + QueueState* state_; + int size_; + std::atomic running_; + std::thread proxyThread_; +}; + +} // namespace gpu_cpu_queue + +#endif // QUEUE_HPP_ From c7269d72efc05a886d22ae034f2f0929591b27c6 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 17:40:30 -0500 Subject: [PATCH 10/16] Update example --- .../rdma_producer_consumer.py | 141 ++++++++++++------ 1 file changed, 98 insertions(+), 43 deletions(-) diff --git a/examples/22_rdma_producer_consumer/rdma_producer_consumer.py b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py index c29f08c3..c626cfd8 100755 --- a/examples/22_rdma_producer_consumer/rdma_producer_consumer.py +++ b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py @@ -8,24 +8,40 @@ import torch.distributed as dist import triton import triton.language as tl +import time import iris.experimental.iris_rdma as iris_rdma @triton.jit -def producer_kernel( - output_ptr, +def producer_put_kernel( + src_ptr, + dst_ptr, n_elements, rank_id, + dst_rank: tl.constexpr, + device_ctx, BLOCK_SIZE: tl.constexpr, ): + """ + Producer kernel that generates data and enqueues RDMA put operations. + """ pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements + # Generate data: rank_id * 1000 + offset data = (rank_id * 1000 + offsets).to(tl.float32) - tl.store(output_ptr + offsets, data, mask=mask) + + # src_ptr is a pointer to float32, adding offsets automatically scales by sizeof(float32) + src_ptrs = src_ptr + offsets + + # dst_ptr is an integer address, need to manually calculate byte offsets + dst_ptrs = dst_ptr + offsets * 4 # multiply by sizeof(float32) to get byte addresses + + # Enqueue RDMA put to remote rank + iris_rdma.put(dst_ptrs, src_ptrs, data, dst_rank, device_ctx, mask) @triton.jit @@ -36,12 +52,18 @@ def consumer_kernel( expected_rank_id, BLOCK_SIZE: tl.constexpr, ): + """ + Consumer kernel that verifies received data. + """ pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements + # Load received data data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + + # Check if it matches expected pattern expected = (expected_rank_id * 1000 + offsets).to(tl.float32) is_correct = (data == expected).to(tl.float32) @@ -49,6 +71,7 @@ def consumer_kernel( def main(): + # Initialize distributed local_rank = int(os.environ.get('LOCAL_RANK', 0)) device_id = torch.device(f"cuda:{local_rank}") @@ -69,66 +92,97 @@ def main(): print(f"[Rank {rank}/{world_size}] Initialized on {device}") - heap_size = 1024 * 1024 * 8 - ctx = iris_rdma.iris(heap_size=heap_size) + # Create Iris RDMA context with queue + heap_size = 1024 * 1024 * 8 # 8MB + queue_size = 512 + ctx = iris_rdma.iris(heap_size=heap_size, queue_size=queue_size) print(f"[Rank {rank}] Iris RDMA initialized") + print(f"[Rank {rank}] - Heap base: {ctx.get_heap_base():#x}") + print(f"[Rank {rank}] - Queue ptr: {ctx.get_queue_ptr():#x}") + + # Get device context for Triton kernels + device_ctx = ctx.get_device_context() + # Allocate buffers in symmetric heap n_elements = 4096 BLOCK_SIZE = 256 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - local_buffer = ctx.zeros(n_elements, dtype=torch.float32) + # Allocate in symmetric heap (GPU memory for GPUDirect RDMA) + local_buffer = ctx.zeros(n_elements, dtype=torch.float32) # Already on GPU + + # Move device_ctx to GPU + device_ctx_gpu = device_ctx.to(device) + + print(f"[Rank {rank}] Local buffer on device: {local_buffer.device}") ctx.barrier() + # ============================================================ + # PRODUCER (Rank 0): Generate data and RDMA put to Rank 1 + # ============================================================ if rank == 0: - print(f"\n[Rank 0] Producing data") + print(f"\n[Rank 0] === Producer: Generating and Sending Data ===") + + # Get remote heap address for rank 1 + dst_rank = 1 + remote_heap_base = ctx.remote_heap_bases[dst_rank] + + # local_buffer is already on GPU, just get its pointer + # Create pointer tensors on GPU + local_ptr = local_buffer.data_ptr() + remote_ptr = remote_heap_base - gpu_buffer = local_buffer.to(device) + print(f"[Rank 0] Launching Triton producer kernel") + print(f"[Rank 0] - Local ptr: {local_ptr:#x}") + print(f"[Rank 0] - Remote ptr: {remote_ptr:#x}") + print(f"[Rank 0] - Dst rank: {dst_rank}") - producer_kernel[grid]( - gpu_buffer, + # Launch producer kernel + # This will: + # 1. Generate data + # 2. Store locally (in registered GPU heap) + # 3. Enqueue RDMA put operations to queue + producer_put_kernel[grid]( + local_buffer, # Pass tensor directly (pointer will be extracted in kernel) + remote_ptr, n_elements, rank_id=0, + dst_rank=dst_rank, + device_ctx=device_ctx_gpu, BLOCK_SIZE=BLOCK_SIZE, ) - local_buffer.copy_(gpu_buffer.cpu()) + # Wait for GPU to finish enqueueing + torch.cuda.synchronize() + print(f"[Rank 0] ✓ Triton kernel completed (operations enqueued to queue)") + print(f"[Rank 0] Grid size was: {triton.cdiv(n_elements, BLOCK_SIZE)} programs") + print(f"[Rank 0] Each program should enqueue 1 work item") - print(f"[Rank 0] First 10: {local_buffer[:10].tolist()}") - - dst_rank = 1 - local_addr = local_buffer.data_ptr() - remote_addr = ctx.remote_heap_bases[dst_rank] - size = n_elements * 4 - - print(f"[Rank 0] RDMA transfer to Rank {dst_rank}") - - ret = ctx.rdma_put(dst_rank, local_addr, remote_addr, size) - - if ret == 0: - import time - for attempt in range(100): - n_comp = ctx.poll_completion(dst_rank) - if n_comp > 0: - print(f"[Rank 0] RDMA completed") - break - time.sleep(0.001) + # Show what we sent + print(f"[Rank 0] Sent data first 10: {local_buffer[:10].tolist()}") + # Barrier: waits for queue to drain AND all ranks to sync + # This ensures all RDMA operations have completed before proceeding + print(f"[Rank {rank}] Waiting at barrier for RDMA completion...") ctx.barrier() + print(f"[Rank {rank}] ✓ Barrier complete, all RDMA operations finished") + # ============================================================ + # CONSUMER (Rank 1): Verify received data + # ============================================================ if rank == 1: - print(f"\n[Rank 1] Consuming data") + print(f"\n[Rank 1] === Consumer: Verifying Received Data ===") - gpu_buffer = local_buffer.to(device) - - print(f"[Rank 1] Received first 10: {local_buffer[:10].tolist()}") + # Show received data (already on GPU) + print(f"[Rank 1] Received data first 10: {local_buffer[:10].tolist()}") + # Verify data (already on GPU) result_buffer = torch.zeros(n_elements, dtype=torch.float32, device=device) consumer_kernel[grid]( - gpu_buffer, + local_buffer, # Already on GPU result_buffer, n_elements, expected_rank_id=0, @@ -142,18 +196,19 @@ def main(): print(f"[Rank 1] Verified: {int(num_correct)}/{num_total}") if num_correct == num_total: - print(f"[Rank 1] SUCCESS!") + print(f"\n" + "="*60) + print(f"[Rank 1] ✓ SUCCESS! Data matches perfectly!") else: - print(f"[Rank 1] FAILED") + print(f"[Rank 1] ✗ FAILED - Data mismatch!") + first_wrong_idx = (result_cpu == 0).nonzero(as_tuple=True)[0] + if len(first_wrong_idx) > 0: + idx = first_wrong_idx[0].item() + print(f"[Rank 1] First wrong at index {idx}") + print(f"[Rank 1] Expected: {0 * 1000 + idx}") + print(f"[Rank 1] Got: {local_buffer[idx].item()}") sys.exit(1) ctx.barrier() - - if rank == 0: - print(f"\n{'='*60}") - print(f"RDMA Producer-Consumer Complete") - print(f"{'='*60}") - dist.destroy_process_group() From 72b5fbfde0fcd7df4f81d4de9c56486b7fb3b325 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 21:03:04 -0500 Subject: [PATCH 11/16] Fix bug --- .../rdma_producer_consumer.py | 100 +++++-------- iris/experimental/iris_rdma.py | 140 ++++++++++++++---- .../iris_rdma/src/iris_manager.hpp | 36 +++++ 3 files changed, 180 insertions(+), 96 deletions(-) diff --git a/examples/22_rdma_producer_consumer/rdma_producer_consumer.py b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py index c626cfd8..d6e12363 100755 --- a/examples/22_rdma_producer_consumer/rdma_producer_consumer.py +++ b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py @@ -15,33 +15,27 @@ @triton.jit def producer_put_kernel( - src_ptr, - dst_ptr, + buffer_ptr, n_elements, - rank_id, dst_rank: tl.constexpr, device_ctx, BLOCK_SIZE: tl.constexpr, ): """ - Producer kernel that generates data and enqueues RDMA put operations. + Producer kernel that enqueues RDMA put operations. + Data must already be in buffer_ptr (filled by fill_data_kernel). + Uses symmetric heap model: same buffer offset in local and remote heap. """ pid = tl.program_id(0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - # Generate data: rank_id * 1000 + offset - data = (rank_id * 1000 + offsets).to(tl.float32) + # Src and dst are the same pointer + ptrs = buffer_ptr + offsets - # src_ptr is a pointer to float32, adding offsets automatically scales by sizeof(float32) - src_ptrs = src_ptr + offsets - - # dst_ptr is an integer address, need to manually calculate byte offsets - dst_ptrs = dst_ptr + offsets * 4 # multiply by sizeof(float32) to get byte addresses - - # Enqueue RDMA put to remote rank - iris_rdma.put(dst_ptrs, src_ptrs, data, dst_rank, device_ctx, mask) + # Enqueue RDMA operation + iris_rdma.put(ptrs, ptrs, dst_rank, device_ctx, mask) @triton.jit @@ -49,11 +43,11 @@ def consumer_kernel( input_ptr, result_ptr, n_elements, - expected_rank_id, BLOCK_SIZE: tl.constexpr, ): """ Consumer kernel that verifies received data. + Expected pattern: ascending numbers 0, 1, 2, ..., n_elements-1 """ pid = tl.program_id(0) block_start = pid * BLOCK_SIZE @@ -63,14 +57,17 @@ def consumer_kernel( # Load received data data = tl.load(input_ptr + offsets, mask=mask, other=0.0) - # Check if it matches expected pattern - expected = (expected_rank_id * 1000 + offsets).to(tl.float32) - is_correct = (data == expected).to(tl.float32) + # Check if it matches expected pattern (0, 1, 2, 3, ...) + expected = offsets.to(data.dtype) + is_correct = (data == expected).to(tl.int32) tl.store(result_ptr + offsets, is_correct, mask=mask) def main(): + + dtype = torch.bfloat16 + # Initialize distributed local_rank = int(os.environ.get('LOCAL_RANK', 0)) device_id = torch.device(f"cuda:{local_rank}") @@ -105,17 +102,12 @@ def main(): device_ctx = ctx.get_device_context() # Allocate buffers in symmetric heap - n_elements = 4096 + n_elements = 4091 BLOCK_SIZE = 256 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - # Allocate in symmetric heap (GPU memory for GPUDirect RDMA) - local_buffer = ctx.zeros(n_elements, dtype=torch.float32) # Already on GPU - - # Move device_ctx to GPU - device_ctx_gpu = device_ctx.to(device) - - print(f"[Rank {rank}] Local buffer on device: {local_buffer.device}") + # Allocate on the symmetric heap + local_buffer = ctx.zeros(n_elements, dtype=dtype) ctx.barrier() @@ -124,50 +116,29 @@ def main(): # ============================================================ if rank == 0: print(f"\n[Rank 0] === Producer: Generating and Sending Data ===") - - # Get remote heap address for rank 1 dst_rank = 1 - remote_heap_base = ctx.remote_heap_bases[dst_rank] - # local_buffer is already on GPU, just get its pointer - # Create pointer tensors on GPU - local_ptr = local_buffer.data_ptr() - remote_ptr = remote_heap_base + # Step 1: Fill buffer with data using PyTorch (no race condition) + print(f"[Rank 0] Filling buffer with data using PyTorch...") + local_buffer.copy_(torch.arange(n_elements, dtype=dtype, device=device)) + print(f"[Rank 0] Data filled, first 10: {local_buffer[:10].tolist()}") - print(f"[Rank 0] Launching Triton producer kernel") - print(f"[Rank 0] - Local ptr: {local_ptr:#x}") - print(f"[Rank 0] - Remote ptr: {remote_ptr:#x}") - print(f"[Rank 0] - Dst rank: {dst_rank}") - - # Launch producer kernel - # This will: - # 1. Generate data - # 2. Store locally (in registered GPU heap) - # 3. Enqueue RDMA put operations to queue + # Step 2: Launch RDMA enqueue kernel (data already in memory) + print(f"[Rank 0] Launching RDMA enqueue kernel...") producer_put_kernel[grid]( - local_buffer, # Pass tensor directly (pointer will be extracted in kernel) - remote_ptr, + local_buffer, n_elements, - rank_id=0, dst_rank=dst_rank, - device_ctx=device_ctx_gpu, + device_ctx=device_ctx, BLOCK_SIZE=BLOCK_SIZE, ) # Wait for GPU to finish enqueueing torch.cuda.synchronize() - print(f"[Rank 0] ✓ Triton kernel completed (operations enqueued to queue)") - print(f"[Rank 0] Grid size was: {triton.cdiv(n_elements, BLOCK_SIZE)} programs") - print(f"[Rank 0] Each program should enqueue 1 work item") - - # Show what we sent - print(f"[Rank 0] Sent data first 10: {local_buffer[:10].tolist()}") + print(f"[Rank 0] RDMA operations enqueued to queue") - # Barrier: waits for queue to drain AND all ranks to sync - # This ensures all RDMA operations have completed before proceeding - print(f"[Rank {rank}] Waiting at barrier for RDMA completion...") ctx.barrier() - print(f"[Rank {rank}] ✓ Barrier complete, all RDMA operations finished") + print(f"[Rank {rank}] Barrier complete, all RDMA operations finished") # ============================================================ # CONSUMER (Rank 1): Verify received data @@ -175,17 +146,16 @@ def main(): if rank == 1: print(f"\n[Rank 1] === Consumer: Verifying Received Data ===") - # Show received data (already on GPU) + # Show received data print(f"[Rank 1] Received data first 10: {local_buffer[:10].tolist()}") - # Verify data (already on GPU) - result_buffer = torch.zeros(n_elements, dtype=torch.float32, device=device) + # Verify data (use int32 for result buffer - stores 0 or 1 for correctness) + result_buffer = torch.zeros(n_elements, dtype=torch.int32, device=device) consumer_kernel[grid]( - local_buffer, # Already on GPU + local_buffer, result_buffer, n_elements, - expected_rank_id=0, BLOCK_SIZE=BLOCK_SIZE, ) @@ -197,14 +167,14 @@ def main(): if num_correct == num_total: print(f"\n" + "="*60) - print(f"[Rank 1] ✓ SUCCESS! Data matches perfectly!") + print(f"[Rank 1] SUCCESS! Data matches perfectly!") else: - print(f"[Rank 1] ✗ FAILED - Data mismatch!") + print(f"[Rank 1] FAILED - Data mismatch!") first_wrong_idx = (result_cpu == 0).nonzero(as_tuple=True)[0] if len(first_wrong_idx) > 0: idx = first_wrong_idx[0].item() print(f"[Rank 1] First wrong at index {idx}") - print(f"[Rank 1] Expected: {0 * 1000 + idx}") + print(f"[Rank 1] Expected: {idx}") print(f"[Rank 1] Got: {local_buffer[idx].item()}") sys.exit(1) diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py index b04ce019..6bb63358 100644 --- a/iris/experimental/iris_rdma.py +++ b/iris/experimental/iris_rdma.py @@ -346,6 +346,44 @@ def iris(heap_size=1 << 30, process_group=None, queue_size=512): # Triton Device-Side APIs ############################################################################# +@triton.jit +def _translate(ptr, from_rank, to_rank, heap_bases): + """ + Translate a pointer from one rank's address space to another. + + This implements the symmetric heap model where each rank has a heap at + a different base address, but offsets are preserved across ranks. + + Args: + ptr: Pointer in from_rank's address space + from_rank: Source rank ID + to_rank: Target rank ID + heap_bases: Pointer to array of heap base addresses + + Returns: + Translated pointer in to_rank's address space + """ + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + + # Convert to int to compute difference + ptr_int = ptr.to(tl.uint64) + + # Find the offset from from_rank heap + offset = ptr_int - from_base + + # Byte cast for byte offset addition + to_base_byte = to_base.to(tl.pointer_type(tl.int8)) + + # Find the offset into the to_rank heap + translated_ptr_byte = to_base_byte + offset + + # Cast back to original pointer type + translated_ptr = translated_ptr_byte.to(ptr.dtype) + + return translated_ptr + + @triton.jit def _wait_for_completion(queue_ptr, queue_pos): """ @@ -445,14 +483,35 @@ def _enqueue_rdma_op(dst_ptr, src_ptr, to_rank: tl.constexpr, op_code: tl.conste # Extract source address (min of pointer block where data is stored) src_ptr_u64 = src_ptr.to(tl.uint64) src_ptr_val = tl.min(src_ptr_u64, axis=0) - - # Calculate size in bytes from pointer range - # max_ptr - min_ptr gives us the byte distance to the last element - # Add element_size to include the last element itself max_src_ptr = tl.max(src_ptr_u64, axis=0) - element_size_bytes = 4 # float32 - num_bytes = (max_src_ptr - src_ptr_val + element_size_bytes).to(tl.uint32) - size_bytes = num_bytes + + # Infer element size from pointer type + # src_ptr is a block of pointers with a specific element type (e.g., pointer) + # The pointer dtype tells us the element type, which has a known size + # Map Triton dtypes to their byte sizes + ptr_dtype = src_ptr.dtype.element_ty # Get the element type that the pointer points to + + # Get element size in bytes from the dtype + # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + # Default to 4 bytes for unknown types + element_size_bytes = 4 + + # Calculate total size in bytes + # Count number of valid elements based on mask + mask_int = mask.to(tl.int32) + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) # Write header fields (but NOT size_bytes yet - it's the ready flag) # Write dst_ptr (offset 0) @@ -475,44 +534,52 @@ def _enqueue_rdma_op(dst_ptr, src_ptr, to_rank: tl.constexpr, op_code: tl.conste @triton.jit -def put(dst_ptr, src_ptr, data, dst_rank: tl.constexpr, device_ctx, mask): +def put(dst_ptr, src_ptr, dst_rank: tl.constexpr, device_ctx, mask): """ RDMA put (write) operation from Triton kernel. - Enqueues data to be written to remote rank via RDMA. - Data must first be stored in the registered heap at src_ptr location. + Uses symmetric heap model: dst_ptr is in current rank's address space, + and will be automatically translated to remote rank's address space. + + IMPORTANT: Data must be stored at src_ptr BEFORE calling this function. + This avoids race conditions between GPU writes and CPU RDMA reads. The CPU proxy thread will dequeue and perform the actual RDMA write. Args: - dst_ptr: Destination pointer (remote address) - can be block of pointers - src_ptr: Source pointer (local address in registered heap) - can be block of pointers - data: Data values to write (block) - will be stored at src_ptr + dst_ptr: Destination pointer in CURRENT rank's address space (symmetric heap) + src_ptr: Source pointer (local address in registered heap) where data is already stored - can be block of pointers dst_rank: Target rank ID (must be compile-time constant) device_ctx: Device context from iris_rdma.get_device_context() mask: Triton mask for valid elements Example: >>> @triton.jit - >>> def kernel(dst_ptr, src_ptr, device_ctx, dst_rank, BLOCK_SIZE: tl.constexpr): + >>> def kernel(local_buffer, device_ctx, dst_rank, BLOCK_SIZE: tl.constexpr): >>> pid = tl.program_id(0) >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> >>> data = generate_data(offsets) - >>> # Store data locally first (in registered heap) - >>> tl.store(src_ptr + offsets, data, mask=mask) - >>> # Enqueue RDMA operation - >>> iris_rdma.put(dst_ptr + offsets, src_ptr + offsets, data, dst_rank, device_ctx, mask) + >>> src_ptrs = local_buffer + offsets + >>> dst_ptrs = local_buffer + offsets # Same offset, symmetric heap! + >>> + >>> # Store data FIRST to avoid race condition + >>> tl.store(src_ptrs, data, mask=mask) + >>> + >>> # Then enqueue RDMA operation + >>> iris_rdma.put(dst_ptrs, src_ptrs, dst_rank, device_ctx, mask) """ - # Extract queue pointer from device context + # Extract context fields # Context format: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] + my_rank = tl.load(device_ctx + 0) queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 - # Store data in registered heap first - tl.store(src_ptr, data, mask=mask) + # Translate dst_ptr from current rank's address space to remote rank's + translated_dst_ptr = _translate(dst_ptr, my_rank, dst_rank, heap_bases) - # Enqueue PUT operation (op_code=1) - _enqueue_rdma_op(dst_ptr, src_ptr, dst_rank, 1, queue_ptr, mask) + # Enqueue PUT operation (op_code=1) with translated address + _enqueue_rdma_op(translated_dst_ptr, src_ptr, dst_rank, 1, queue_ptr, mask) @triton.jit @@ -520,36 +587,47 @@ def get(dst_ptr, src_ptr, from_rank: tl.constexpr, device_ctx, mask): """ RDMA get (read) operation from Triton kernel. + Uses symmetric heap model: src_ptr is in current rank's address space, + and will be automatically translated to remote rank's address space. + Enqueues a request to read data from remote rank via RDMA and WAITS for completion. The CPU proxy thread will dequeue, perform the RDMA read, then pop the item. This function spins until the tail pointer advances, then data is ready at dst_ptr. Args: dst_ptr: Local destination pointer where data will be written - can be block of pointers - src_ptr: Source pointer (remote address) - can be block of pointers + src_ptr: Source pointer in CURRENT rank's address space (symmetric heap) from_rank: Source rank ID (must be compile-time constant) device_ctx: Device context from iris_rdma.get_device_context() mask: Triton mask for valid elements Example: >>> @triton.jit - >>> def kernel(local_ptr, remote_ptr, device_ctx, from_rank, BLOCK_SIZE: tl.constexpr): + >>> def kernel(local_buffer, device_ctx, from_rank, BLOCK_SIZE: tl.constexpr): >>> pid = tl.program_id(0) >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) >>> mask = offsets < n_elements >>> + >>> src_ptrs = local_buffer + offsets # Same offset in symmetric heap! + >>> dst_ptrs = local_buffer + offsets >>> # RDMA read from remote rank - blocks until complete - >>> iris_rdma.get(local_ptr + offsets, remote_ptr + offsets, from_rank, device_ctx, mask) + >>> iris_rdma.get(dst_ptrs, src_ptrs, from_rank, device_ctx, mask) >>> - >>> # Data is now ready at local_ptr, can use it immediately - >>> data = tl.load(local_ptr + offsets, mask=mask) + >>> # Data is now ready at dst_ptrs, can use it immediately + >>> data = tl.load(dst_ptrs, mask=mask) """ - # Extract queue pointer from device context + # Extract context fields + # Context format: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] + my_rank = tl.load(device_ctx + 0) queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate src_ptr from current rank's address space to remote rank's + translated_src_ptr = _translate(src_ptr, my_rank, from_rank, heap_bases) # Enqueue GET operation (op_code=2) - # For GET: src_ptr is remote source, dst_ptr is local destination - queue_pos = _enqueue_rdma_op(src_ptr, dst_ptr, from_rank, 2, queue_ptr, mask) + # For GET: translated_src_ptr is remote source, dst_ptr is local destination + queue_pos = _enqueue_rdma_op(translated_src_ptr, dst_ptr, from_rank, 2, queue_ptr, mask) # Wait for CPU to complete the RDMA read _wait_for_completion(queue_ptr, queue_pos) diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp index d4559cdd..16dbe716 100644 --- a/iris/experimental/iris_rdma/src/iris_manager.hpp +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -161,6 +161,42 @@ class IrisManager { DEBUG_PRINT("[IrisManager] PUT: rank=%d src=%lx dst=%lx size=%zu", dst_rank, src_ptr, dst_ptr, size); + // Debug: print first few values if environment variable is set + static bool debug_data = (getenv("IRIS_DEBUG_DATA") != nullptr); + static const char* dtype_env = getenv("IRIS_DTYPE"); + if (debug_data && size >= 4) { + // Determine element size and print accordingly + bool is_bf16 = (dtype_env && strcmp(dtype_env, "bfloat16") == 0); + bool is_fp16 = (dtype_env && strcmp(dtype_env, "float16") == 0); + bool is_fp32 = (!dtype_env || strcmp(dtype_env, "float32") == 0); + + if (is_bf16 || is_fp16) { + // 2-byte types - convert to float for display + int elem_count = std::min((int)(size / 2), 10); + uint16_t* data_ptr = (uint16_t*)local_addr; + fprintf(stderr, "[DEBUG-PUT] rank=%d dst=%d size=%zu (bf16) src=%lx dst=%lx: ", + backend_->getRank(), dst_rank, size, src_ptr, dst_ptr); + + for (int i = 0; i < elem_count; i++) { + // Convert bfloat16 to float: shift left 16 bits + uint32_t fp32_bits = ((uint32_t)data_ptr[i]) << 16; + float value = *reinterpret_cast(&fp32_bits); + fprintf(stderr, "%.1f ", value); + } + fprintf(stderr, "\n"); + } else if (is_fp32) { + // 4-byte float32 + int elem_count = std::min((int)(size / 4), 10); + float* float_ptr = (float*)local_addr; + fprintf(stderr, "[DEBUG-PUT] rank=%d dst=%d size=%zu (fp32) src=%lx dst=%lx: ", + backend_->getRank(), dst_rank, size, src_ptr, dst_ptr); + for (int i = 0; i < elem_count; i++) { + fprintf(stderr, "%.1f ", float_ptr[i]); + } + fprintf(stderr, "\n"); + } + } + int ret = backend_->rdmaWrite(dst_rank, local_addr, dst_ptr, size); if (ret != 0) { fprintf(stderr, "[IrisManager] RDMA write failed: dst=%d size=%lu\n", dst_rank, size); From a7a5333335fe989c3018acb7358a60ddcc4d88f1 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 21:20:22 -0500 Subject: [PATCH 12/16] Cleanup --- iris/experimental/iris_rdma.py | 10 +- .../iris_rdma/python/bindings.cpp | 5 - .../iris_rdma/src/iris_manager.hpp | 92 ++++++++++--------- iris/experimental/iris_rdma/src/queue.hpp | 87 +----------------- 4 files changed, 57 insertions(+), 137 deletions(-) diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py index 6bb63358..ceb4347d 100644 --- a/iris/experimental/iris_rdma.py +++ b/iris/experimental/iris_rdma.py @@ -223,11 +223,15 @@ def barrier(self): Example: >>> ctx.barrier() # Wait for all ranks and RDMA completion """ - # First, wait for queue to drain (all work processed) - self.wait_queue_drain() - # Then synchronize with other ranks + # First, synchronize with all GPUs + torch.cuda.synchronize() + + # Then, synchronize with all ranks dist.barrier() + + # Finally, wait for queue to drain (all work processed) + self.wait_queue_drain() def wait_queue_drain(self, timeout=30.0): """ diff --git a/iris/experimental/iris_rdma/python/bindings.cpp b/iris/experimental/iris_rdma/python/bindings.cpp index 6790741d..1ed4e314 100644 --- a/iris/experimental/iris_rdma/python/bindings.cpp +++ b/iris/experimental/iris_rdma/python/bindings.cpp @@ -96,11 +96,6 @@ PYBIND11_MODULE(_iris_rdma_backend, m) { auto t = THPVariable_Unpack(obj.ptr()); ptr = t.data_ptr(); actual_size = t.numel() * t.element_size(); - - // Note: For GPU tensors, ibv_reg_mr will work if: - // 1. GPUDirect RDMA is enabled (check with ibstat/ibv_devinfo) - // 2. The memory is allocated with hipMalloc (native GPU memory) - // PyTorch tensors should work as they use hipMalloc internally } else { throw std::runtime_error("Expected a PyTorch tensor or integer address"); diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp index 16dbe716..ce7cb745 100644 --- a/iris/experimental/iris_rdma/src/iris_manager.hpp +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -124,15 +124,8 @@ class IrisManager { */ void proxyLoop() { gpu_cpu_queue::WorkItem item; - int checkCounter = 1000; - - while (true) { - // Check if should stop - if (checkCounter-- == 0) { - checkCounter = 1000; - if (!running_) break; - } + while (running_) { // Poll for work from GPU queue if (queue_->poll(item)) { processWorkItem(item); @@ -140,6 +133,53 @@ class IrisManager { } } + /** + * @brief Debug helper to print work item data + */ + void debugPrintWorkItem(const gpu_cpu_queue::WorkItem& item) { + static bool debug_enabled = (getenv("IRIS_DEBUG_DATA") != nullptr); + if (!debug_enabled || item.header.size_bytes < 4) return; + + // Extract info from work item + auto op_type = static_cast(item.header.op_type); + const char* op_name = (op_type == gpu_cpu_queue::OperationType::PUT) ? "PUT" : + (op_type == gpu_cpu_queue::OperationType::GET) ? "GET" : "OP"; + int dst_rank = item.header.rank; + uint64_t src_ptr = item.header.src_ptr; + uint64_t dst_ptr = item.header.dst_ptr; + size_t size = item.header.size_bytes; + void* data = (void*)src_ptr; + + static const char* dtype_env = getenv("IRIS_DTYPE"); + bool is_bf16 = (dtype_env && strcmp(dtype_env, "bfloat16") == 0); + bool is_fp16 = (dtype_env && strcmp(dtype_env, "float16") == 0); + bool is_fp32 = (!dtype_env || strcmp(dtype_env, "float32") == 0); + + fprintf(stderr, "[DEBUG-%s] rank=%d dst=%d size=%zu ", + op_name, backend_->getRank(), dst_rank, size); + + if (is_bf16 || is_fp16) { + // 2-byte types + int elem_count = std::min((int)(size / 2), 10); + uint16_t* data_ptr = (uint16_t*)data; + fprintf(stderr, "(bf16) src=%lx dst=%lx: ", src_ptr, dst_ptr); + for (int i = 0; i < elem_count; i++) { + uint32_t fp32_bits = ((uint32_t)data_ptr[i]) << 16; + float value = *reinterpret_cast(&fp32_bits); + fprintf(stderr, "%.1f ", value); + } + } else if (is_fp32) { + // 4-byte types + int elem_count = std::min((int)(size / 4), 10); + float* float_ptr = (float*)data; + fprintf(stderr, "(fp32) src=%lx dst=%lx: ", src_ptr, dst_ptr); + for (int i = 0; i < elem_count; i++) { + fprintf(stderr, "%.1f ", float_ptr[i]); + } + } + fprintf(stderr, "\n"); + } + /** * @brief Process a single work item from the queue */ @@ -161,41 +201,7 @@ class IrisManager { DEBUG_PRINT("[IrisManager] PUT: rank=%d src=%lx dst=%lx size=%zu", dst_rank, src_ptr, dst_ptr, size); - // Debug: print first few values if environment variable is set - static bool debug_data = (getenv("IRIS_DEBUG_DATA") != nullptr); - static const char* dtype_env = getenv("IRIS_DTYPE"); - if (debug_data && size >= 4) { - // Determine element size and print accordingly - bool is_bf16 = (dtype_env && strcmp(dtype_env, "bfloat16") == 0); - bool is_fp16 = (dtype_env && strcmp(dtype_env, "float16") == 0); - bool is_fp32 = (!dtype_env || strcmp(dtype_env, "float32") == 0); - - if (is_bf16 || is_fp16) { - // 2-byte types - convert to float for display - int elem_count = std::min((int)(size / 2), 10); - uint16_t* data_ptr = (uint16_t*)local_addr; - fprintf(stderr, "[DEBUG-PUT] rank=%d dst=%d size=%zu (bf16) src=%lx dst=%lx: ", - backend_->getRank(), dst_rank, size, src_ptr, dst_ptr); - - for (int i = 0; i < elem_count; i++) { - // Convert bfloat16 to float: shift left 16 bits - uint32_t fp32_bits = ((uint32_t)data_ptr[i]) << 16; - float value = *reinterpret_cast(&fp32_bits); - fprintf(stderr, "%.1f ", value); - } - fprintf(stderr, "\n"); - } else if (is_fp32) { - // 4-byte float32 - int elem_count = std::min((int)(size / 4), 10); - float* float_ptr = (float*)local_addr; - fprintf(stderr, "[DEBUG-PUT] rank=%d dst=%d size=%zu (fp32) src=%lx dst=%lx: ", - backend_->getRank(), dst_rank, size, src_ptr, dst_ptr); - for (int i = 0; i < elem_count; i++) { - fprintf(stderr, "%.1f ", float_ptr[i]); - } - fprintf(stderr, "\n"); - } - } + debugPrintWorkItem(item); int ret = backend_->rdmaWrite(dst_rank, local_addr, dst_ptr, size); if (ret != 0) { diff --git a/iris/experimental/iris_rdma/src/queue.hpp b/iris/experimental/iris_rdma/src/queue.hpp index dd2a5bb6..326335be 100644 --- a/iris/experimental/iris_rdma/src/queue.hpp +++ b/iris/experimental/iris_rdma/src/queue.hpp @@ -13,7 +13,6 @@ #include #include #include -#include namespace gpu_cpu_queue { @@ -36,8 +35,6 @@ struct alignas(16) WorkItemHeader { uint8_t reserved; // Reserved for future use }; -// WorkItem is now just the header - no data array needed! -// All data lives in the registered symmetric heap // Note: Completion is signaled by tail pointer advancement, not a flag struct alignas(16) WorkItem { WorkItemHeader header; @@ -55,7 +52,7 @@ struct QueueState { // CPU-side queue management class Queue { public: - explicit Queue(int size = 512) : size_(size), running_(false) { + explicit Queue(int size = 512) : size_(size) { // Allocate pinned memory for QueueState struct (GPU needs to read this) hipHostMalloc(&state_, sizeof(QueueState)); @@ -79,9 +76,6 @@ class Queue { } ~Queue() { - if (running_) { - stopProxy(); - } hipHostFree(state_->items); hipFree(state_->head); hipHostFree(state_->tail); @@ -125,21 +119,6 @@ class Queue { reinterpret_cast*>(state_->tail)->store(newTail, std::memory_order_release); } - // Start proxy thread - void startProxy() { - if (running_) return; - - running_ = true; - proxyThread_ = std::thread([this]() { this->proxyLoop(); }); - } - - void stopProxy() { - running_ = false; - if (proxyThread_.joinable()) { - proxyThread_.join(); - } - } - // Get queue statistics uint64_t getTail() const { return *state_->tail; } @@ -159,72 +138,8 @@ class Queue { } private: - void proxyLoop() { - WorkItem item; - int checkCounter = 1000; - - while (true) { - // Check if should stop - if (checkCounter-- == 0) { - checkCounter = 1000; - if (!running_) break; - } - - // Poll for work - if (poll(item)) { - // Process the work item: print to stdout (later: send to NIC) - OperationType op = static_cast(item.header.op_type); - - // Get operation name - const char* opName = "UNKNOWN"; - switch (op) { - case OperationType::NOP: - opName = "NOP"; - break; - case OperationType::PUT: - opName = "PUT"; - break; - case OperationType::GET: - opName = "GET"; - break; - case OperationType::FLUSH: - opName = "FLUSH"; - break; - } - - // Silent processing (uncomment to debug) - // std::cout << "[CPU Proxy] Op=" << opName << " (0x" << std::hex << (int)item.header.op_type << std::dec << ")" - // << " rank=" << item.header.rank << " dst=0x" << std::hex << item.header.dst_ptr << std::dec - // << " size=" << item.header.block_size << " first_values=["; - // for (int i = 0; i < std::min(5, (int)item.header.block_size); i++) { - // std::cout << item.data[i]; - // if (i < std::min(5, (int)item.header.block_size) - 1) std::cout << ", "; - // } - // std::cout << "...]" << std::endl; - - // Process based on operation type - if (op == OperationType::PUT) { - // TODO: Replace with actual NIC write - // nic->write(item.header.dst_ptr, item.data, item.header.block_size, item.header.rank); - pop(); - - } else if (op == OperationType::GET) { - std::cout << "[CPU Proxy] Processing GET operation - rank=" << item.header.rank - << " size=" << item.header.size_bytes << std::endl; - - // This is a test/debug proxy - IrisManager has the real proxy with RDMA - // Just mark as complete for testing - pop(); - std::cout << "[CPU Proxy] GET operation complete" << std::endl; - } - } - } - } - QueueState* state_; int size_; - std::atomic running_; - std::thread proxyThread_; }; } // namespace gpu_cpu_queue From 1fc0ac5e3f29c04ddb7b280677692d7cac7448f6 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 22:23:39 -0500 Subject: [PATCH 13/16] Cleanup case --- iris/experimental/iris_rdma.py | 8 +- .../iris_rdma/python/bindings.cpp | 155 ++++++++-------- iris/experimental/iris_rdma/src/ibv_utils.hpp | 12 +- .../iris_rdma/src/iris_manager.hpp | 94 +++++----- .../iris_rdma/src/network_backend.hpp | 170 +++++++++--------- iris/experimental/iris_rdma/src/queue.hpp | 57 +++--- .../experimental/iris_rdma/src/queue_pair.hpp | 40 ++--- .../iris_rdma/src/torch_bootstrap.hpp | 20 ++- 8 files changed, 279 insertions(+), 277 deletions(-) diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py index ceb4347d..e82224f2 100644 --- a/iris/experimental/iris_rdma.py +++ b/iris/experimental/iris_rdma.py @@ -93,8 +93,8 @@ def __init__(self, heap_size=1 << 30, process_group=None, queue_size=512): torch.cuda.set_device(self.device_id) - # Create TorchBootstrap - self._bootstrap = backend.TorchBootstrap(process_group) + # Create torch_bootstrap + self._bootstrap = backend.torch_bootstrap(process_group) # Allocate symmetric heap (CPU pinned memory for now) # TODO: Support GPU memory with GPUDirect RDMA @@ -105,12 +105,12 @@ def __init__(self, heap_size=1 << 30, process_group=None, queue_size=512): # Create GPU memory pool self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) - self._manager = backend.IrisManager(self._bootstrap, self.memory_pool, queue_size) + self._manager = backend.rdma_proxy(self._bootstrap, self.memory_pool, queue_size) self._manager.start_proxy_thread() self._backend = self._manager - logger.info(f"[Rank {self.rank}] Using IrisManager with queue (size={queue_size})") + logger.info(f"[Rank {self.rank}] Using rdma_proxy with queue (size={queue_size})") self.remote_heap_bases = [] for i in range(self.world_size): diff --git a/iris/experimental/iris_rdma/python/bindings.cpp b/iris/experimental/iris_rdma/python/bindings.cpp index 1ed4e314..1af97958 100644 --- a/iris/experimental/iris_rdma/python/bindings.cpp +++ b/iris/experimental/iris_rdma/python/bindings.cpp @@ -16,70 +16,69 @@ #include "iris_manager.hpp" namespace py = pybind11; -using namespace iris_rdma; PYBIND11_MODULE(_iris_rdma_backend, m) { m.doc() = "Iris RDMA Backend: InfiniBand RDMA with PyTorch Integration"; // Expose NICVendor enum - py::enum_(m, "NICVendor") - .value("NONE", NICVendor::NONE) - .value("IONIC", NICVendor::IONIC) - .value("BNXT", NICVendor::BNXT) - .value("MLX5", NICVendor::MLX5) + py::enum_(m, "nic_vendor") + .value("NONE", iris::rdma::nic_vendor::NONE) + .value("IONIC", iris::rdma::nic_vendor::IONIC) + .value("BNXT", iris::rdma::nic_vendor::BNXT) + .value("MLX5", iris::rdma::nic_vendor::MLX5) .export_values(); - // Expose QPInfo struct - py::class_(m, "QPInfo") + // Expose qp_info_t struct + py::class_(m, "qp_info_t") .def(py::init<>()) - .def_readwrite("qp_num", &QPInfo::qp_num) - .def_readwrite("lkey", &QPInfo::lkey) - .def_readwrite("rkey", &QPInfo::rkey) - .def_readwrite("dst_rank", &QPInfo::dst_rank) - .def("__repr__", [](const QPInfo& info) { - return ""; }); - // Expose TorchBootstrap - py::class_>(m, - "TorchBootstrap") + // Expose torch_bootstrap + py::class_>(m, + "torch_bootstrap") .def(py::init([](py::object pg_obj) { // Extract c10d::ProcessGroup from Python object auto pg_ptr = pg_obj.cast>(); - return std::make_shared(pg_ptr); + return std::make_shared(pg_ptr); }), py::arg("process_group")) - .def("get_rank", &TorchBootstrap::getRank) - .def("get_world_size", &TorchBootstrap::getWorldSize) - .def("barrier", &TorchBootstrap::barrier); + .def("get_rank", &iris::rdma::torch_bootstrap::get_rank) + .def("get_world_size", &iris::rdma::torch_bootstrap::get_world_size) + .def("barrier", &iris::rdma::torch_bootstrap::barrier); - // Expose QueuePair (read-only access) - py::class_(m, "QueuePair") - .def("get_qp_num", &QueuePair::getQPNum) - .def("get_lkey", &QueuePair::getLKey) - .def("get_rkey", &QueuePair::getRKey) - .def("get_dst_rank", &QueuePair::getDstRank) - .def("get_info", &QueuePair::getInfo) - .def("__repr__", [](const QueuePair& qp) { - return ""; + // Expose queue_pair (read-only access) + py::class_(m, "queue_pair") + .def("get_qp_num", &iris::queue_pair::get_qp_num) + .def("get_lkey", &iris::queue_pair::get_lkey) + .def("get_rkey", &iris::queue_pair::get_rkey) + .def("get_dst_rank", &iris::queue_pair::get_dst_rank) + .def("get_info", &iris::queue_pair::get_info) + .def("__repr__", [](const iris::queue_pair& qp) { + return ""; }); - // Expose NetworkBackend - py::class_(m, "NetworkBackend") - .def(py::init, const char*>(), + // Expose network_backend + py::class_(m, "network_backend") + .def(py::init, const char*>(), py::arg("bootstrap"), py::arg("device_name") = nullptr, - "Create NetworkBackend with PyTorch bootstrap") - .def("init", &NetworkBackend::init, + "Create network_backend with PyTorch bootstrap") + .def("init", &iris::network_backend::init, "Initialize the network (setup QPs, transition to RTS)") .def( "register_memory", - [](NetworkBackend& self, py::object obj, size_t size = 0) { + [](iris::network_backend& self, py::object obj, size_t size = 0) { void* ptr = nullptr; size_t actual_size = size; @@ -101,52 +100,52 @@ PYBIND11_MODULE(_iris_rdma_backend, m) { throw std::runtime_error("Expected a PyTorch tensor or integer address"); } - self.registerMemory(ptr, actual_size); + self.register_memory(ptr, actual_size); }, py::arg("obj"), py::arg("size") = 0, "Register memory for RDMA (supports CPU pinned or GPU memory via GPUDirect)") - .def("get_qp", &NetworkBackend::getQP, py::arg("dst_rank"), + .def("get_qp", &iris::network_backend::get_qp, py::arg("dst_rank"), py::return_value_policy::reference_internal, "Get queue pair for destination rank") - .def("get_qp_info", &NetworkBackend::getQPInfo, py::arg("dst_rank"), + .def("get_qp_info", &iris::network_backend::get_qp_info, py::arg("dst_rank"), "Get QP info for destination rank") - .def("get_rank", &NetworkBackend::getRank, "Get rank") - .def("get_world_size", &NetworkBackend::getWorldSize, "Get world size") - .def("get_remote_heap_base", &NetworkBackend::getRemoteHeapBase, + .def("get_rank", &iris::network_backend::get_rank, "Get rank") + .def("get_world_size", &iris::network_backend::get_world_size, "Get world size") + .def("get_remote_heap_base", &iris::network_backend::get_remote_heap_base, py::arg("rank"), "Get remote heap base address for a rank") - .def("get_heap_base", &NetworkBackend::getHeapBase, + .def("get_heap_base", &iris::network_backend::get_heap_base, "Get local heap base address") - .def("get_heap_size", &NetworkBackend::getHeapSize, + .def("get_heap_size", &iris::network_backend::get_heap_size, "Get heap size in bytes") .def("rdma_write", - [](NetworkBackend& self, int dst_rank, uint64_t local_addr, + [](iris::network_backend& self, int dst_rank, uint64_t local_addr, uint64_t remote_addr, size_t size, uint64_t wr_id) { - return self.rdmaWrite(dst_rank, reinterpret_cast(local_addr), + return self.rdma_write(dst_rank, reinterpret_cast(local_addr), remote_addr, size, wr_id); }, py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), py::arg("size"), py::arg("wr_id") = 0, "RDMA write to remote rank (local_addr is integer address)") .def("rdma_read", - [](NetworkBackend& self, int dst_rank, uint64_t local_addr, + [](iris::network_backend& self, int dst_rank, uint64_t local_addr, uint64_t remote_addr, size_t size, uint64_t wr_id) { - return self.rdmaRead(dst_rank, reinterpret_cast(local_addr), + return self.rdma_read(dst_rank, reinterpret_cast(local_addr), remote_addr, size, wr_id); }, py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), py::arg("size"), py::arg("wr_id") = 0, "RDMA read from remote rank (local_addr is integer address)") - .def("poll_cq", &NetworkBackend::pollCQ, + .def("poll_cq", &iris::network_backend::poll_cq, py::arg("dst_rank"), py::arg("max_completions") = 1, "Poll completion queue for RDMA operations") - .def("__repr__", [](const NetworkBackend& backend) { - return ""; + .def("__repr__", [](const iris::network_backend& backend) { + return ""; }); - py::class_(m, "IrisManager") - .def(py::init([](std::shared_ptr bootstrap, py::object heap_tensor, int queue_size) { + py::class_(m, "rdma_proxy") + .def(py::init([](std::shared_ptr bootstrap, py::object heap_tensor, int queue_size) { // Extract heap pointer from tensor if (!THPVariable_Check(heap_tensor.ptr())) { throw std::runtime_error("heap_tensor must be a PyTorch tensor"); @@ -155,58 +154,58 @@ PYBIND11_MODULE(_iris_rdma_backend, m) { void* heap_ptr = heap.data_ptr(); size_t heap_size = heap.numel() * heap.element_size(); - return new iris::IrisManager(bootstrap, heap_ptr, heap_size, queue_size); + return new iris::rdma_proxy(bootstrap, heap_ptr, heap_size, queue_size); }), py::arg("bootstrap"), py::arg("heap_tensor"), py::arg("queue_size") = 512, - "Create IrisManager with NetworkBackend + Queue + Proxy Thread") - .def("start_proxy_thread", &iris::IrisManager::startProxyThread, + "Create rdma_proxy with network_backend + Queue + Proxy Thread") + .def("start_proxy_thread", &iris::rdma_proxy::start_proxy_thread, "Start proxy thread that processes RDMA operations from queue") - .def("stop_proxy_thread", &iris::IrisManager::stopProxyThread, + .def("stop_proxy_thread", &iris::rdma_proxy::stop_proxy_thread, "Stop proxy thread") .def("get_queue_ptr", - [](iris::IrisManager& self) { - return reinterpret_cast(self.getQueuePtr()); + [](iris::rdma_proxy& self) { + return reinterpret_cast(self.get_queue_ptr()); }, "Get queue pointer for Triton kernels") - .def("get_heap_base", &iris::IrisManager::getHeapBase, + .def("get_heap_base", &iris::rdma_proxy::get_heap_base, "Get local heap base address") - .def("get_remote_heap_base", &iris::IrisManager::getRemoteHeapBase, + .def("get_remote_heap_base", &iris::rdma_proxy::get_remote_heap_base, py::arg("rank"), "Get remote heap base address for a rank") - .def("get_rank", &iris::IrisManager::getRank, "Get rank") - .def("get_world_size", &iris::IrisManager::getWorldSize, "Get world size") - .def("is_queue_empty", &iris::IrisManager::isQueueEmpty, + .def("get_rank", &iris::rdma_proxy::get_rank, "Get rank") + .def("get_world_size", &iris::rdma_proxy::get_world_size, "Get world size") + .def("is_queue_empty", &iris::rdma_proxy::is_queue_empty, "Check if queue is empty (all work items processed)") .def("rdma_write", - [](iris::IrisManager& self, int dst_rank, uint64_t local_addr, + [](iris::rdma_proxy& self, int dst_rank, uint64_t local_addr, uint64_t remote_addr, size_t size, uint64_t wr_id) { - auto backend = self.getBackend(); - return backend->rdmaWrite(dst_rank, reinterpret_cast(local_addr), + auto backend = self.get_backend(); + return backend->rdma_write(dst_rank, reinterpret_cast(local_addr), remote_addr, size, wr_id); }, py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), py::arg("size"), py::arg("wr_id") = 0, "RDMA write to remote rank (local_addr is integer address)") .def("rdma_read", - [](iris::IrisManager& self, int dst_rank, uint64_t local_addr, + [](iris::rdma_proxy& self, int dst_rank, uint64_t local_addr, uint64_t remote_addr, size_t size, uint64_t wr_id) { - auto backend = self.getBackend(); - return backend->rdmaRead(dst_rank, reinterpret_cast(local_addr), + auto backend = self.get_backend(); + return backend->rdma_read(dst_rank, reinterpret_cast(local_addr), remote_addr, size, wr_id); }, py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), py::arg("size"), py::arg("wr_id") = 0, "RDMA read from remote rank (local_addr is integer address)") .def("poll_cq", - [](iris::IrisManager& self, int dst_rank, int max_completions) { - auto backend = self.getBackend(); - return backend->pollCQ(dst_rank, max_completions); + [](iris::rdma_proxy& self, int dst_rank, int max_completions) { + auto backend = self.get_backend(); + return backend->poll_cq(dst_rank, max_completions); }, py::arg("dst_rank"), py::arg("max_completions") = 1, "Poll completion queue for RDMA operations") - .def("__repr__", [](const iris::IrisManager& mgr) { - return ""; + .def("__repr__", [](const iris::rdma_proxy& mgr) { + return ""; }); } diff --git a/iris/experimental/iris_rdma/src/ibv_utils.hpp b/iris/experimental/iris_rdma/src/ibv_utils.hpp index ee55544f..b922572e 100644 --- a/iris/experimental/iris_rdma/src/ibv_utils.hpp +++ b/iris/experimental/iris_rdma/src/ibv_utils.hpp @@ -7,7 +7,8 @@ #include #include -namespace iris_rdma { +namespace iris { +namespace rdma { // Error checking macros #define CHECK_ZERO(expr, msg) \ @@ -36,10 +37,10 @@ namespace iris_rdma { } while (0) // Vendor detection -enum class NICVendor { NONE, IONIC, BNXT, MLX5 }; +enum class nic_vendor { NONE, IONIC, BNXT, MLX5 }; // QP destination info for connection -struct QPDestInfo { +struct qp_dest_info_t { int lid; int qpn; int psn; @@ -47,7 +48,7 @@ struct QPDestInfo { }; // QP metadata exposed to Python -struct QPInfo { +struct qp_info_t { uint32_t qp_num; uint32_t lkey; uint32_t rkey; @@ -90,5 +91,6 @@ inline int ibv_mtu_to_int(enum ibv_mtu mtu) { } } -} // namespace iris_rdma +} // namespace rdma +} // namespace iris diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp index ce7cb745..bce4c2ec 100644 --- a/iris/experimental/iris_rdma/src/iris_manager.hpp +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -21,12 +21,12 @@ namespace iris { /** - * @brief Complete Iris RDMA Manager + * @brief Complete Iris RDMA Proxy * - * Integration of NetworkBackend + TritonDeviceQueue + Proxy Thread + * Integration of network_backend + TritonDeviceQueue + Proxy Thread * Provides a unified interface for Triton kernels to perform RDMA operations */ -class IrisManager { +class rdma_proxy { public: /** * @brief Constructor @@ -35,44 +35,44 @@ class IrisManager { * @param heap_size Size of symmetric heap in bytes * @param queue_size Queue capacity (default: 512) */ - IrisManager(std::shared_ptr bootstrap, - void* heap_base, - size_t heap_size, - int queue_size = 512) + rdma_proxy(std::shared_ptr bootstrap, + void* heap_base, + size_t heap_size, + int queue_size = 512) : heap_base_((uint64_t)heap_base), heap_size_(heap_size), running_(false) { - // Step 1: Create NetworkBackend and initialize - backend_ = std::make_unique(bootstrap); + // Step 1: Create network_backend and initialize + backend_ = std::make_unique(bootstrap); backend_->init(); // Step 2: Register symmetric heap (collective operation) - backend_->registerMemory(heap_base, heap_size); + backend_->register_memory(heap_base, heap_size); // Step 3: Create CPU-GPU queue - queue_ = std::make_unique(queue_size); + queue_ = std::make_unique(queue_size); } - ~IrisManager() { + ~rdma_proxy() { if (running_) { - stopProxyThread(); + stop_proxy_thread(); } } /** * @brief Start the proxy thread that processes RDMA operations */ - void startProxyThread() { + void start_proxy_thread() { if (running_) return; running_ = true; - proxy_thread_ = std::thread(&IrisManager::proxyLoop, this); + proxy_thread_ = std::thread(&rdma_proxy::proxy_loop, this); } /** * @brief Stop the proxy thread */ - void stopProxyThread() { + void stop_proxy_thread() { running_ = false; if (proxy_thread_.joinable()) { proxy_thread_.join(); @@ -82,53 +82,53 @@ class IrisManager { /** * @brief Get the queue state pointer (for passing to Triton kernels) */ - gpu_cpu_queue::QueueState* getQueuePtr() { - return queue_->getQueuePtr(); + rdma::queue_state_t* get_queue_ptr() { + return queue_->get_queue_ptr(); } /** * @brief Get heap base address */ - uint64_t getHeapBase() { return heap_base_; } + uint64_t get_heap_base() { return heap_base_; } /** - * @brief Get the NetworkBackend (for direct RDMA operations) + * @brief Get the network_backend (for direct RDMA operations) */ - iris_rdma::NetworkBackend* getBackend() { return backend_.get(); } + network_backend* get_backend() { return backend_.get(); } /** * @brief Get remote heap base for a given rank */ - uint64_t getRemoteHeapBase(int rank) { - return backend_->getRemoteHeapBase(rank); + uint64_t get_remote_heap_base(int rank) { + return backend_->get_remote_heap_base(rank); } /** * @brief Get rank */ - int getRank() const { return backend_->getRank(); } + int get_rank() const { return backend_->get_rank(); } /** * @brief Get world size */ - int getWorldSize() const { return backend_->getWorldSize(); } + int get_world_size() const { return backend_->get_world_size(); } /** * @brief Check if queue is empty (all work processed) */ - bool isQueueEmpty() const { return queue_->isEmpty(); } + bool is_queue_empty() const { return queue_->is_empty(); } private: /** * @brief Main proxy loop - processes RDMA operations from GPU queue */ - void proxyLoop() { - gpu_cpu_queue::WorkItem item; + void proxy_loop() { + rdma::work_item_t item; while (running_) { // Poll for work from GPU queue if (queue_->poll(item)) { - processWorkItem(item); + process_work_item(item); } } } @@ -136,14 +136,14 @@ class IrisManager { /** * @brief Debug helper to print work item data */ - void debugPrintWorkItem(const gpu_cpu_queue::WorkItem& item) { + void debug_print_work_item(const rdma::work_item_t& item) { static bool debug_enabled = (getenv("IRIS_DEBUG_DATA") != nullptr); if (!debug_enabled || item.header.size_bytes < 4) return; // Extract info from work item - auto op_type = static_cast(item.header.op_type); - const char* op_name = (op_type == gpu_cpu_queue::OperationType::PUT) ? "PUT" : - (op_type == gpu_cpu_queue::OperationType::GET) ? "GET" : "OP"; + auto op_type = static_cast(item.header.op_type); + const char* op_name = (op_type == rdma::operation_type::PUT) ? "PUT" : + (op_type == rdma::operation_type::GET) ? "GET" : "OP"; int dst_rank = item.header.rank; uint64_t src_ptr = item.header.src_ptr; uint64_t dst_ptr = item.header.dst_ptr; @@ -156,7 +156,7 @@ class IrisManager { bool is_fp32 = (!dtype_env || strcmp(dtype_env, "float32") == 0); fprintf(stderr, "[DEBUG-%s] rank=%d dst=%d size=%zu ", - op_name, backend_->getRank(), dst_rank, size); + op_name, backend_->get_rank(), dst_rank, size); if (is_bf16 || is_fp16) { // 2-byte types @@ -183,8 +183,8 @@ class IrisManager { /** * @brief Process a single work item from the queue */ - void processWorkItem(const gpu_cpu_queue::WorkItem& item) { - auto op_type = static_cast(item.header.op_type); + void process_work_item(const rdma::work_item_t& item) { + auto op_type = static_cast(item.header.op_type); int dst_rank = item.header.rank; // Get addresses from queue metadata @@ -193,7 +193,7 @@ class IrisManager { size_t size = item.header.size_bytes; switch (op_type) { - case gpu_cpu_queue::OperationType::PUT: { + case rdma::operation_type::PUT: { // RDMA Write: Data is already in the registered heap at src_ptr // No memcpy needed - just RDMA directly from heap! void* local_addr = (void*)src_ptr; @@ -201,16 +201,16 @@ class IrisManager { DEBUG_PRINT("[IrisManager] PUT: rank=%d src=%lx dst=%lx size=%zu", dst_rank, src_ptr, dst_ptr, size); - debugPrintWorkItem(item); + debug_print_work_item(item); - int ret = backend_->rdmaWrite(dst_rank, local_addr, dst_ptr, size); + int ret = backend_->rdma_write(dst_rank, local_addr, dst_ptr, size); if (ret != 0) { fprintf(stderr, "[IrisManager] RDMA write failed: dst=%d size=%lu\n", dst_rank, size); } else { // Poll for completion int n = 0; for (int attempt = 0; attempt < 100; attempt++) { - n = backend_->pollCQ(dst_rank, 1); + n = backend_->poll_cq(dst_rank, 1); if (n > 0) break; std::this_thread::sleep_for(std::chrono::microseconds(10)); } @@ -224,7 +224,7 @@ class IrisManager { break; } - case gpu_cpu_queue::OperationType::GET: { + case rdma::operation_type::GET: { // RDMA Read: Read from remote directly into registered heap at src_ptr // GPU will read from heap after completion void* local_addr = (void*)src_ptr; @@ -232,14 +232,14 @@ class IrisManager { DEBUG_PRINT("[IrisManager] GET: rank=%d src=%lx dst=%lx size=%zu", dst_rank, dst_ptr, src_ptr, size); - int ret = backend_->rdmaRead(dst_rank, local_addr, dst_ptr, size); + int ret = backend_->rdma_read(dst_rank, local_addr, dst_ptr, size); if (ret != 0) { fprintf(stderr, "[IrisManager] RDMA read failed: dst=%d size=%lu\n", dst_rank, size); } else { // Poll for completion int n = 0; for (int attempt = 0; attempt < 100; attempt++) { - n = backend_->pollCQ(dst_rank, 1); + n = backend_->poll_cq(dst_rank, 1); if (n > 0) break; std::this_thread::sleep_for(std::chrono::microseconds(10)); } @@ -253,14 +253,14 @@ class IrisManager { break; } - case gpu_cpu_queue::OperationType::FLUSH: { + case rdma::operation_type::FLUSH: { // Flush all pending operations for this rank DEBUG_PRINT("[IrisManager] FLUSH: rank=%d", dst_rank); int total = 0; int n; do { - n = backend_->pollCQ(dst_rank, 16); + n = backend_->poll_cq(dst_rank, 16); if (n > 0) total += n; } while (n > 0); @@ -274,8 +274,8 @@ class IrisManager { } } - std::unique_ptr backend_; - std::unique_ptr queue_; + std::unique_ptr backend_; + std::unique_ptr queue_; uint64_t heap_base_; size_t heap_size_; diff --git a/iris/experimental/iris_rdma/src/network_backend.hpp b/iris/experimental/iris_rdma/src/network_backend.hpp index 8fee042e..03a450ab 100644 --- a/iris/experimental/iris_rdma/src/network_backend.hpp +++ b/iris/experimental/iris_rdma/src/network_backend.hpp @@ -24,7 +24,7 @@ #include #endif -namespace iris_rdma { +namespace iris { /** * @brief Main network backend for InfiniBand setup @@ -36,21 +36,21 @@ namespace iris_rdma { * - Memory registration * - QP connection info exchange */ -class NetworkBackend { +class network_backend { public: /** * @brief Constructor * @param bootstrap PyTorch bootstrap for cross-rank communication * @param device_name Optional device name (NULL for auto-detect) */ - NetworkBackend(std::shared_ptr bootstrap, - const char* device_name = nullptr) + network_backend(std::shared_ptr bootstrap, + const char* device_name = nullptr) : bootstrap_(bootstrap), requested_dev_(device_name), context_(nullptr), pd_orig_(nullptr), pd_parent_(nullptr), - vendor_(NICVendor::NONE), + vendor_(rdma::nic_vendor::NONE), port_(1), gid_index_(0), heap_mr_(nullptr), @@ -61,16 +61,16 @@ class NetworkBackend { if (!bootstrap_) { throw std::runtime_error("Bootstrap cannot be null"); } - rank_ = bootstrap_->getRank(); - world_size_ = bootstrap_->getWorldSize(); - DEBUG_PRINT("NetworkBackend created: rank=%d, world_size=%d", rank_, world_size_); + rank_ = bootstrap_->get_rank(); + world_size_ = bootstrap_->get_world_size(); + DEBUG_PRINT("network_backend created: rank=%d, world_size=%d", rank_, world_size_); } /** * @brief Destructor - cleanup InfiniBand resources */ - ~NetworkBackend() { - DEBUG_PRINT("NetworkBackend cleanup started"); + ~network_backend() { + DEBUG_PRINT("network_backend cleanup started"); qps_.clear(); @@ -118,15 +118,15 @@ class NetworkBackend { * @brief Initialize the network (setup QPs, transition to RTS) */ void init() { - DEBUG_PRINT("NetworkBackend::init() started"); + DEBUG_PRINT("network_backend::init() started"); - autodetectDVLibs(); - openIBDevice(); - createQueues(); - exchangeQPDestInfo(); - modifyQPsResetToInit(); - modifyQPsInitToRTR(); - modifyQPsRTRToRTS(); + autodetect_dv_libs(); + open_ib_device(); + create_queues(); + exchange_qp_dest_info(); + modify_qps_reset_to_init(); + modify_qps_init_to_rtr(); + modify_qps_rtr_to_rts(); bootstrap_->barrier(); DEBUG_PRINT("NetworkBackend::init() completed"); @@ -137,7 +137,7 @@ class NetworkBackend { * @param ptr Pointer to memory region * @param size Size in bytes */ - void registerMemory(void* ptr, size_t size) { + void register_memory(void* ptr, size_t size) { DEBUG_PRINT("Registering memory: ptr=%p, size=%zu", ptr, size); int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | @@ -163,7 +163,7 @@ class NetworkBackend { rkeys_.resize(world_size_); std::vector all_rkeys(world_size_); all_rkeys[rank_] = heap_mr_->rkey; - bootstrap_->allGather(all_rkeys.data(), sizeof(uint32_t)); + bootstrap_->all_gather(all_rkeys.data(), sizeof(uint32_t)); for (int i = 0; i < world_size_; i++) { rkeys_[i] = all_rkeys[i]; } @@ -172,7 +172,7 @@ class NetworkBackend { remote_heap_bases_.resize(world_size_); std::vector all_heap_bases(world_size_); all_heap_bases[rank_] = heap_base_; - bootstrap_->allGather(all_heap_bases.data(), sizeof(uint64_t)); + bootstrap_->all_gather(all_heap_bases.data(), sizeof(uint64_t)); for (int i = 0; i < world_size_; i++) { remote_heap_bases_[i] = all_heap_bases[i]; } @@ -181,8 +181,8 @@ class NetworkBackend { uint32_t lkey = heap_mr_->lkey; for (int i = 0; i < world_size_; i++) { if (i < qps_.size() && qps_[i]) { - qps_[i]->setLKey(lkey); - qps_[i]->setRKey(rkeys_[i]); + qps_[i]->set_lkey(lkey); + qps_[i]->set_rkey(rkeys_[i]); } } @@ -195,7 +195,7 @@ class NetworkBackend { * @param dst_rank Destination rank * @return Pointer to QueuePair object */ - QueuePair* getQP(int dst_rank) { + queue_pair* get_qp(int dst_rank) { if (dst_rank >= 0 && dst_rank < qps_.size()) { return qps_[dst_rank].get(); } @@ -207,12 +207,12 @@ class NetworkBackend { * @param dst_rank Destination rank * @return QPInfo structure */ - QPInfo getQPInfo(int dst_rank) { - QueuePair* qp = getQP(dst_rank); + rdma::qp_info_t get_qp_info(int dst_rank) { + queue_pair* qp = get_qp(dst_rank); if (qp) { - return qp->getInfo(); + return qp->get_info(); } - return QPInfo{0, 0, 0, dst_rank}; + return rdma::qp_info_t{0, 0, 0, dst_rank}; } @@ -221,19 +221,19 @@ class NetworkBackend { /** * @brief Get rank */ - int getRank() const { return rank_; } + int get_rank() const { return rank_; } /** * @brief Get world size */ - int getWorldSize() const { return world_size_; } + int get_world_size() const { return world_size_; } /** * @brief Get remote heap base address for a rank * @param rank Remote rank * @return Remote heap base address (0 if not registered) */ - uint64_t getRemoteHeapBase(int rank) const { + uint64_t get_remote_heap_base(int rank) const { if (rank >= 0 && rank < remote_heap_bases_.size()) { return remote_heap_bases_[rank]; } @@ -244,13 +244,13 @@ class NetworkBackend { * @brief Get local heap base address * @return Local heap base address (0 if not registered) */ - uint64_t getHeapBase() const { return heap_base_; } + uint64_t get_heap_base() const { return heap_base_; } /** * @brief Get heap size * @return Heap size in bytes (0 if not registered) */ - size_t getHeapSize() const { return heap_size_; } + size_t get_heap_size() const { return heap_size_; } /** * @brief RDMA Write operation @@ -261,9 +261,9 @@ class NetworkBackend { * @param wr_id Work request ID (for completion tracking) * @return 0 on success, non-zero on error */ - int rdmaWrite(int dst_rank, void* local_addr, uint64_t remote_addr, + int rdma_write(int dst_rank, void* local_addr, uint64_t remote_addr, size_t size, uint64_t wr_id = 0) { - QueuePair* qp = getQP(dst_rank); + queue_pair* qp = get_qp(dst_rank); if (!qp) { return -1; } @@ -271,7 +271,7 @@ class NetworkBackend { struct ibv_sge sge; sge.addr = (uintptr_t)local_addr; sge.length = size; - sge.lkey = qp->getLKey(); + sge.lkey = qp->get_lkey(); struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); @@ -281,10 +281,10 @@ class NetworkBackend { wr.opcode = IBV_WR_RDMA_WRITE; wr.send_flags = IBV_SEND_SIGNALED; wr.wr.rdma.remote_addr = remote_addr; - wr.wr.rdma.rkey = qp->getRKey(); + wr.wr.rdma.rkey = qp->get_rkey(); struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(qp->getIBVQP(), &wr, &bad_wr); + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); DEBUG_PRINT("RDMA Write to rank %d: local=%p remote=%lx size=%zu ret=%d", dst_rank, local_addr, remote_addr, size, ret); @@ -301,9 +301,9 @@ class NetworkBackend { * @param wr_id Work request ID (for completion tracking) * @return 0 on success, non-zero on error */ - int rdmaRead(int dst_rank, void* local_addr, uint64_t remote_addr, + int rdma_read(int dst_rank, void* local_addr, uint64_t remote_addr, size_t size, uint64_t wr_id = 0) { - QueuePair* qp = getQP(dst_rank); + queue_pair* qp = get_qp(dst_rank); if (!qp) { return -1; } @@ -311,7 +311,7 @@ class NetworkBackend { struct ibv_sge sge; sge.addr = (uintptr_t)local_addr; sge.length = size; - sge.lkey = qp->getLKey(); + sge.lkey = qp->get_lkey(); struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); @@ -321,10 +321,10 @@ class NetworkBackend { wr.opcode = IBV_WR_RDMA_READ; wr.send_flags = IBV_SEND_SIGNALED; wr.wr.rdma.remote_addr = remote_addr; - wr.wr.rdma.rkey = qp->getRKey(); + wr.wr.rdma.rkey = qp->get_rkey(); struct ibv_send_wr* bad_wr; - int ret = ibv_post_send(qp->getIBVQP(), &wr, &bad_wr); + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); DEBUG_PRINT("RDMA Read from rank %d: local=%p remote=%lx size=%zu ret=%d", dst_rank, local_addr, remote_addr, size, ret); @@ -338,15 +338,15 @@ class NetworkBackend { * @param max_completions Maximum number of completions to poll * @return Number of completions polled (negative on error) */ - int pollCQ(int dst_rank, int max_completions = 1) { - QueuePair* qp = getQP(dst_rank); + int poll_cq(int dst_rank, int max_completions = 1) { + queue_pair* qp = get_qp(dst_rank); if (!qp) { return -1; } struct ibv_wc wc[16]; int num_to_poll = (max_completions < 16) ? max_completions : 16; - int n = ibv_poll_cq(qp->getIBVCQ(), num_to_poll, wc); + int n = ibv_poll_cq(qp->get_ibv_cq(), num_to_poll, wc); if (n < 0) { DEBUG_PRINT("CQ poll error for rank %d", dst_rank); @@ -370,7 +370,7 @@ class NetworkBackend { private: // Bootstrap - std::shared_ptr bootstrap_; + std::shared_ptr bootstrap_; int rank_; int world_size_; @@ -379,7 +379,7 @@ class NetworkBackend { struct ibv_context* context_; struct ibv_pd* pd_orig_; struct ibv_pd* pd_parent_; // For MLX5/IONIC - NICVendor vendor_; + rdma::nic_vendor vendor_; // Port configuration struct ibv_port_attr portinfo_; @@ -395,9 +395,9 @@ class NetworkBackend { std::vector remote_heap_bases_; // Heap base addresses from all ranks // Queue pairs - std::vector> qps_; + std::vector> qps_; std::vector cqs_; - std::vector dest_info_; + std::vector dest_info_; // Dynamic library handles for vendor-specific libraries void* mlx5dv_handle_; @@ -406,29 +406,29 @@ class NetworkBackend { // Setup functions (extracted from rocSHMEM) // Vendor-specific init - void autodetectDVLibs() { + void autodetect_dv_libs() { DEBUG_PRINT("Auto-detecting vendor libraries..."); // Try MLX5 - if (mlx5DVDLInit() == 0) { - vendor_ = NICVendor::MLX5; + if (mlx5_dv_dl_init() == 0) { + vendor_ = rdma::nic_vendor::MLX5; DEBUG_PRINT("Detected MLX5 vendor"); return; } // Try BNXT - if (bnxtDVDLInit() == 0) { - vendor_ = NICVendor::BNXT; + if (bnxt_dv_dl_init() == 0) { + vendor_ = rdma::nic_vendor::BNXT; DEBUG_PRINT("Detected BNXT vendor"); return; } // Default to standard verbs - vendor_ = NICVendor::NONE; + vendor_ = rdma::nic_vendor::NONE; DEBUG_PRINT("Using standard InfiniBand verbs"); } - int mlx5DVDLInit() { + int mlx5_dv_dl_init() { mlx5dv_handle_ = dlopen("libmlx5.so", RTLD_NOW); if (!mlx5dv_handle_) { mlx5dv_handle_ = dlopen("libmlx5.so.1", RTLD_NOW); @@ -442,7 +442,7 @@ class NetworkBackend { return 0; } - int bnxtDVDLInit() { + int bnxt_dv_dl_init() { bnxtdv_handle_ = dlopen("libbnxt_re.so", RTLD_NOW); if (!bnxtdv_handle_) { bnxtdv_handle_ = dlopen("/usr/local/lib/libbnxt_re.so", RTLD_NOW); @@ -456,7 +456,7 @@ class NetworkBackend { return 0; } - void openIBDevice() { + void open_ib_device() { DEBUG_PRINT("Opening InfiniBand device..."); struct ibv_device** device_list = nullptr; @@ -488,26 +488,26 @@ class NetworkBackend { // Open device context_ = ibv_open_device(device); CHECK_NNULL(context_, "ibv_open_device"); - dump_ibv_context(context_); - dump_ibv_device(context_->device); + rdma::dump_ibv_context(context_); + rdma::dump_ibv_device(context_->device); // Allocate protection domain pd_orig_ = ibv_alloc_pd(context_); CHECK_NNULL(pd_orig_, "ibv_alloc_pd"); - dump_ibv_pd(pd_orig_); + rdma::dump_ibv_pd(pd_orig_); // Create parent domain for MLX5/IONIC - if (vendor_ == NICVendor::MLX5) { - createParentDomain(); + if (vendor_ == rdma::nic_vendor::MLX5) { + create_parent_domain(); } // Query port int err = ibv_query_port(context_, port_, &portinfo_); CHECK_ZERO(err, "ibv_query_port"); - dump_ibv_port_attr(&portinfo_); + rdma::dump_ibv_port_attr(&portinfo_); // Select GID index - selectGIDIndex(); + select_gid_index(); ibv_free_device_list(device_list); @@ -515,7 +515,7 @@ class NetworkBackend { ibv_get_device_name(context_->device)); } - void createParentDomain() { + void create_parent_domain() { DEBUG_PRINT("Creating parent domain..."); struct ibv_parent_domain_init_attr pattr; @@ -527,10 +527,10 @@ class NetworkBackend { pd_parent_ = ibv_alloc_parent_domain(context_, &pattr); CHECK_NNULL(pd_parent_, "ibv_alloc_parent_domain"); - dump_ibv_pd(pd_parent_); + rdma::dump_ibv_pd(pd_parent_); } - void selectGIDIndex() { + void select_gid_index() { DEBUG_PRINT("Selecting GID index..."); const uint8_t local_gid_prefix[2] = {0xFE, 0x80}; @@ -571,7 +571,7 @@ class NetworkBackend { DEBUG_PRINT("Selected GID index: %d", gid_index_); } - void createQueues() { + void create_queues() { DEBUG_PRINT("Creating queues..."); int ncqes = 64; // Number of CQ entries @@ -583,13 +583,13 @@ class NetworkBackend { qps_.resize(world_size_); // Create CQs and QPs - createCQs(ncqes); - createQPs(sq_length); + create_cqs(ncqes); + create_qps(sq_length); DEBUG_PRINT("Created %d queue pairs", world_size_); } - void createCQs(int ncqes) { + void create_cqs(int ncqes) { DEBUG_PRINT("Creating completion queues: ncqes=%d", ncqes); struct ibv_cq_init_attr_ex cq_attr; @@ -615,7 +615,7 @@ class NetworkBackend { } } - void createQPs(int sq_length) { + void create_qps(int sq_length) { DEBUG_PRINT("Creating queue pairs: sq_length=%d", sq_length); struct ibv_qp_init_attr_ex attr; @@ -636,28 +636,28 @@ class NetworkBackend { struct ibv_qp* qp = ibv_create_qp_ex(context_, &attr); CHECK_NNULL(qp, "ibv_create_qp_ex"); - qps_[i] = std::make_unique(qp, cqs_[i], i, vendor_); + qps_[i] = std::make_unique(qp, cqs_[i], i, vendor_); } } - void exchangeQPDestInfo() { + void exchange_qp_dest_info() { DEBUG_PRINT("Exchanging QP destination info..."); // Fill local dest info for (int i = 0; i < world_size_; i++) { dest_info_[i].lid = portinfo_.lid; - dest_info_[i].qpn = qps_[i]->getQPNum(); + dest_info_[i].qpn = qps_[i]->get_qp_num(); dest_info_[i].psn = 0; dest_info_[i].gid = gid_; } // All-gather dest info - bootstrap_->allGather(dest_info_.data(), sizeof(QPDestInfo)); + bootstrap_->all_gather(dest_info_.data(), sizeof(rdma::qp_dest_info_t)); DEBUG_PRINT("QP destination info exchanged"); } - void modifyQPsResetToInit() { + void modify_qps_reset_to_init() { DEBUG_PRINT("Transitioning QPs: RESET -> INIT"); struct ibv_qp_attr attr; @@ -673,12 +673,12 @@ class NetworkBackend { IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; for (int i = 0; i < world_size_; i++) { - int err = ibv_modify_qp(qps_[i]->getIBVQP(), &attr, attr_mask); + int err = ibv_modify_qp(qps_[i]->get_ibv_qp(), &attr, attr_mask); CHECK_ZERO(err, "modify_qp (RESET->INIT)"); } } - void modifyQPsInitToRTR() { + void modify_qps_init_to_rtr() { DEBUG_PRINT("Transitioning QPs: INIT -> RTR"); struct ibv_qp_attr attr; @@ -712,12 +712,12 @@ class NetworkBackend { attr.ah_attr.dlid = dest_info_[i].lid; } - int err = ibv_modify_qp(qps_[i]->getIBVQP(), &attr, attr_mask); + int err = ibv_modify_qp(qps_[i]->get_ibv_qp(), &attr, attr_mask); CHECK_ZERO(err, "modify_qp (INIT->RTR)"); } } - void modifyQPsRTRToRTS() { + void modify_qps_rtr_to_rts() { DEBUG_PRINT("Transitioning QPs: RTR -> RTS"); struct ibv_qp_attr attr; @@ -735,11 +735,11 @@ class NetworkBackend { for (int i = 0; i < world_size_; i++) { attr.sq_psn = dest_info_[i].psn; - int err = ibv_modify_qp(qps_[i]->getIBVQP(), &attr, attr_mask); + int err = ibv_modify_qp(qps_[i]->get_ibv_qp(), &attr, attr_mask); CHECK_ZERO(err, "modify_qp (RTR->RTS)"); } } }; -} // namespace iris_rdma +} // namespace iris diff --git a/iris/experimental/iris_rdma/src/queue.hpp b/iris/experimental/iris_rdma/src/queue.hpp index 326335be..f52339e6 100644 --- a/iris/experimental/iris_rdma/src/queue.hpp +++ b/iris/experimental/iris_rdma/src/queue.hpp @@ -1,8 +1,7 @@ // GPU-to-CPU Queue - C++ Host Side // Exposes queue pointer to Python/Triton -#ifndef QUEUE_HPP_ -#define QUEUE_HPP_ +#pragma once #include @@ -14,10 +13,11 @@ #include #include -namespace gpu_cpu_queue { +namespace iris { +namespace rdma { // Operation types - simplified for Iris -enum class OperationType : uint8_t { +enum class operation_type : uint8_t { NOP = 0, PUT = 1, // RDMA write GET = 2, // RDMA read @@ -26,23 +26,23 @@ enum class OperationType : uint8_t { // Work item structure - metadata only, no data storage // Data is stored in the registered symmetric heap -struct alignas(16) WorkItemHeader { +struct alignas(16) work_item_header_t { uint64_t dst_ptr; // Destination pointer (where to write on remote) uint64_t src_ptr; // Source pointer (offset in local registered heap) uint32_t size_bytes; // Size in bytes to transfer (WRITE LAST as ready flag) uint16_t rank; // Remote rank - uint8_t op_type; // Operation type (see OperationType enum) + uint8_t op_type; // Operation type (see operation_type enum) uint8_t reserved; // Reserved for future use }; // Note: Completion is signaled by tail pointer advancement, not a flag -struct alignas(16) WorkItem { - WorkItemHeader header; +struct alignas(16) work_item_t { + work_item_header_t header; }; // Queue state visible to both CPU and GPU -struct QueueState { - WorkItem* items; // Queue buffer (pinned host memory) +struct queue_state_t { + work_item_t* items; // Queue buffer (pinned host memory) uint64_t* head; // Head pointer (device memory, GPU writes) uint64_t* tail; // Tail pointer (host memory, CPU writes, GPU reads) uint64_t* tailCache; // Cached tail (device memory) @@ -50,15 +50,15 @@ struct QueueState { }; // CPU-side queue management -class Queue { +class queue { public: - explicit Queue(int size = 512) : size_(size) { - // Allocate pinned memory for QueueState struct (GPU needs to read this) - hipHostMalloc(&state_, sizeof(QueueState)); + explicit queue(int size = 512) : size_(size) { + // Allocate pinned memory for queue_state_t struct (GPU needs to read this) + hipHostMalloc(&state_, sizeof(queue_state_t)); // Allocate pinned memory for queue items - hipHostMalloc(&state_->items, size * sizeof(WorkItem)); - memset(state_->items, 0, size * sizeof(WorkItem)); + hipHostMalloc(&state_->items, size * sizeof(work_item_t)); + memset(state_->items, 0, size * sizeof(work_item_t)); // Allocate device memory for head hipMalloc(&state_->head, sizeof(uint64_t)); @@ -75,7 +75,7 @@ class Queue { state_->size = size; } - ~Queue() { + ~queue() { hipHostFree(state_->items); hipFree(state_->head); hipHostFree(state_->tail); @@ -84,12 +84,12 @@ class Queue { } // Get raw pointer to queue state for Triton - QueueState* getQueuePtr() { return state_; } + queue_state_t* get_queue_ptr() { return state_; } // Poll for new work item (non-blocking) - bool poll(WorkItem& item) { + bool poll(work_item_t& item) { uint64_t currentTail = *state_->tail; - WorkItem* ptr = &state_->items[currentTail % size_]; + work_item_t* ptr = &state_->items[currentTail % size_]; // Atomic load of size_bytes (acquire semantics) - use as ready flag // size_bytes == 0 means slot is empty/processed @@ -102,7 +102,7 @@ class Queue { } // Copy entire work item (just header now, no data array) - memcpy(&item, ptr, sizeof(WorkItem)); + memcpy(&item, ptr, sizeof(work_item_t)); return true; } @@ -120,28 +120,27 @@ class Queue { } // Get queue statistics - uint64_t getTail() const { return *state_->tail; } + uint64_t get_tail() const { return *state_->tail; } - uint64_t getHead() const { + uint64_t get_head() const { uint64_t h; hipMemcpy(&h, state_->head, sizeof(uint64_t), hipMemcpyDeviceToHost); return h; } - int getSize() const { return size_; } + int get_size() const { return size_; } // Check if queue is empty (all work processed) - bool isEmpty() const { + bool is_empty() const { uint64_t h; hipMemcpy(&h, state_->head, sizeof(uint64_t), hipMemcpyDeviceToHost); return h == *state_->tail; } private: - QueueState* state_; + queue_state_t* state_; int size_; }; -} // namespace gpu_cpu_queue - -#endif // QUEUE_HPP_ +} // namespace rdma +} // namespace iris diff --git a/iris/experimental/iris_rdma/src/queue_pair.hpp b/iris/experimental/iris_rdma/src/queue_pair.hpp index dff923c9..ecd90c53 100644 --- a/iris/experimental/iris_rdma/src/queue_pair.hpp +++ b/iris/experimental/iris_rdma/src/queue_pair.hpp @@ -5,7 +5,7 @@ #include #include "ibv_utils.hpp" -namespace iris_rdma { +namespace iris { /** * @brief Simplified Queue Pair wrapper for host-side operations @@ -13,7 +13,7 @@ namespace iris_rdma { * Unlike the full rocSHMEM QueuePair, this version only maintains * metadata needed for RDMA operations from Python/host code. */ -class QueuePair { +class queue_pair { public: /** * @brief Constructor @@ -22,10 +22,10 @@ class QueuePair { * @param dst_rank Destination rank for this QP * @param vendor NIC vendor type */ - inline QueuePair(struct ibv_qp* qp, - struct ibv_cq* cq, - int dst_rank, - NICVendor vendor) + inline queue_pair(struct ibv_qp* qp, + struct ibv_cq* cq, + int dst_rank, + rdma::nic_vendor vendor) : qp_(qp), cq_(cq), dst_rank_(dst_rank), @@ -41,55 +41,55 @@ class QueuePair { /** * @brief Destructor */ - inline ~QueuePair() { - DEBUG_PRINT("QueuePair destroyed: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + inline ~queue_pair() { + DEBUG_PRINT("queue_pair destroyed: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); } /** * @brief Get QP number */ - uint32_t getQPNum() const { return qp_num_; } + uint32_t get_qp_num() const { return qp_num_; } /** * @brief Get local key for memory region */ - uint32_t getLKey() const { return lkey_; } + uint32_t get_lkey() const { return lkey_; } /** * @brief Get remote key for destination rank */ - uint32_t getRKey() const { return rkey_; } + uint32_t get_rkey() const { return rkey_; } /** * @brief Get destination rank */ - int getDstRank() const { return dst_rank_; } + int get_dst_rank() const { return dst_rank_; } /** * @brief Set remote key (after exchange) */ - void setRKey(uint32_t rkey) { rkey_ = rkey; } + void set_rkey(uint32_t rkey) { rkey_ = rkey; } /** * @brief Set local key (from memory registration) */ - void setLKey(uint32_t lkey) { lkey_ = lkey; } + void set_lkey(uint32_t lkey) { lkey_ = lkey; } /** * @brief Get underlying ibv_qp pointer */ - struct ibv_qp* getIBVQP() { return qp_; } + struct ibv_qp* get_ibv_qp() { return qp_; } /** * @brief Get underlying ibv_cq pointer */ - struct ibv_cq* getIBVCQ() { return cq_; } + struct ibv_cq* get_ibv_cq() { return cq_; } /** * @brief Get QP info for Python */ - inline QPInfo getInfo() const { - QPInfo info; + inline rdma::qp_info_t get_info() const { + rdma::qp_info_t info; info.qp_num = qp_num_; info.lkey = lkey_; info.rkey = rkey_; @@ -101,12 +101,12 @@ class QueuePair { struct ibv_qp* qp_; struct ibv_cq* cq_; int dst_rank_; - NICVendor vendor_; + rdma::nic_vendor vendor_; uint32_t qp_num_; uint32_t lkey_; uint32_t rkey_; }; -} // namespace iris_rdma +} // namespace iris diff --git a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp index e714b8dd..73aca004 100644 --- a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp +++ b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp @@ -10,39 +10,40 @@ #include #include "ibv_utils.hpp" -namespace iris_rdma { +namespace iris { +namespace rdma { /** * @brief Bootstrap implementation using PyTorch Distributed * * Wraps PyTorch's c10d process group to provide synchronization - * primitives needed for InfiniBand setup (allGather, barrier) + * primitives needed for InfiniBand setup (all_gather, barrier) */ -class TorchBootstrap { +class torch_bootstrap { public: /** * @brief Constructor * @param process_group PyTorch distributed process group */ - inline explicit TorchBootstrap(c10::intrusive_ptr process_group) + inline explicit torch_bootstrap(c10::intrusive_ptr process_group) : process_group_(process_group) { if (!process_group_) { throw std::runtime_error("Process group cannot be null"); } rank_ = process_group_->getRank(); world_size_ = process_group_->getSize(); - DEBUG_PRINT("TorchBootstrap initialized: rank=%d, world_size=%d", rank_, world_size_); + DEBUG_PRINT("torch_bootstrap initialized: rank=%d, world_size=%d", rank_, world_size_); } /** * @brief Get rank of current process */ - int getRank() const { return rank_; } + int get_rank() const { return rank_; } /** * @brief Get total number of ranks */ - int getWorldSize() const { return world_size_; } + int get_world_size() const { return world_size_; } /** * @brief All-gather operation @@ -53,7 +54,7 @@ class TorchBootstrap { * @param allData Buffer to hold all gathered data (world_size * size bytes) * @param size Size of data contributed by each rank */ - inline void allGather(void* allData, int size) { + inline void all_gather(void* allData, int size) { auto cpu_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); auto cuda_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -118,5 +119,6 @@ class TorchBootstrap { int world_size_; }; -} // namespace iris_rdma +} // namespace rdma +} // namespace iris From da140815c190a6f585f5124e8ec77fb5570715c5 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 22:23:48 -0500 Subject: [PATCH 14/16] Add logger --- iris/experimental/iris_rdma/src/ibv_utils.hpp | 24 ++--- .../iris_rdma/src/iris_manager.hpp | 28 +++--- iris/experimental/iris_rdma/src/logging.hpp | 89 +++++++++++++++++++ .../iris_rdma/src/network_backend.hpp | 54 +++++------ .../experimental/iris_rdma/src/queue_pair.hpp | 4 +- .../iris_rdma/src/torch_bootstrap.hpp | 2 +- 6 files changed, 142 insertions(+), 59 deletions(-) create mode 100644 iris/experimental/iris_rdma/src/logging.hpp diff --git a/iris/experimental/iris_rdma/src/ibv_utils.hpp b/iris/experimental/iris_rdma/src/ibv_utils.hpp index b922572e..7a9632f1 100644 --- a/iris/experimental/iris_rdma/src/ibv_utils.hpp +++ b/iris/experimental/iris_rdma/src/ibv_utils.hpp @@ -7,6 +7,8 @@ #include #include +#include "logging.hpp" + namespace iris { namespace rdma { @@ -15,8 +17,7 @@ namespace rdma { do { \ int ret = (expr); \ if (ret != 0) { \ - fprintf(stderr, "[ERROR] %s failed with code %d: %s\n", msg, ret, \ - strerror(ret)); \ + LOG_ERROR("%s failed with code %d: %s", msg, ret, strerror(ret)); \ abort(); \ } \ } while (0) @@ -24,18 +25,11 @@ namespace rdma { #define CHECK_NNULL(ptr, msg) \ do { \ if ((ptr) == nullptr) { \ - fprintf(stderr, "[ERROR] %s returned NULL\n", msg); \ + LOG_ERROR("%s returned NULL", msg); \ abort(); \ } \ } while (0) -#define DEBUG_PRINT(fmt, ...) \ - do { \ - if (getenv("IRIS_RDMA_DEBUG")) { \ - fprintf(stderr, "[IRIS_RDMA_DEBUG] " fmt "\n", ##__VA_ARGS__); \ - } \ - } while (0) - // Vendor detection enum class nic_vendor { NONE, IONIC, BNXT, MLX5 }; @@ -57,20 +51,20 @@ struct qp_info_t { // Helper functions inline void dump_ibv_device(struct ibv_device* device) { - DEBUG_PRINT("IBV Device: %s", ibv_get_device_name(device)); + LOG_DEBUG("IBV Device: %s", ibv_get_device_name(device)); } inline void dump_ibv_context(struct ibv_context* ctx) { - DEBUG_PRINT("IBV Context: device=%s", ctx->device->name); + LOG_DEBUG("IBV Context: device=%s", ctx->device->name); } inline void dump_ibv_pd(struct ibv_pd* pd) { - DEBUG_PRINT("IBV PD: handle=%u", pd->handle); + LOG_DEBUG("IBV PD: handle=%u", pd->handle); } inline void dump_ibv_port_attr(struct ibv_port_attr* attr) { - DEBUG_PRINT("Port Attr: state=%d, lid=%d, link_layer=%d, active_mtu=%d", - attr->state, attr->lid, attr->link_layer, attr->active_mtu); + LOG_DEBUG("Port Attr: state=%d, lid=%d, link_layer=%d, active_mtu=%d", + attr->state, attr->lid, attr->link_layer, attr->active_mtu); } inline int ibv_mtu_to_int(enum ibv_mtu mtu) { diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp index bce4c2ec..4bb24c44 100644 --- a/iris/experimental/iris_rdma/src/iris_manager.hpp +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -155,29 +155,29 @@ class rdma_proxy { bool is_fp16 = (dtype_env && strcmp(dtype_env, "float16") == 0); bool is_fp32 = (!dtype_env || strcmp(dtype_env, "float32") == 0); - fprintf(stderr, "[DEBUG-%s] rank=%d dst=%d size=%zu ", - op_name, backend_->get_rank(), dst_rank, size); - if (is_bf16 || is_fp16) { // 2-byte types int elem_count = std::min((int)(size / 2), 10); uint16_t* data_ptr = (uint16_t*)data; - fprintf(stderr, "(bf16) src=%lx dst=%lx: ", src_ptr, dst_ptr); + LOG_DATA_DEBUG("[%s] rank=%d dst=%d size=%zu (bf16) src=%lx dst=%lx: first values", + op_name, backend_->get_rank(), dst_rank, size, src_ptr, dst_ptr); for (int i = 0; i < elem_count; i++) { uint32_t fp32_bits = ((uint32_t)data_ptr[i]) << 16; float value = *reinterpret_cast(&fp32_bits); fprintf(stderr, "%.1f ", value); } + fprintf(stderr, "\n"); } else if (is_fp32) { // 4-byte types int elem_count = std::min((int)(size / 4), 10); float* float_ptr = (float*)data; - fprintf(stderr, "(fp32) src=%lx dst=%lx: ", src_ptr, dst_ptr); + LOG_DATA_DEBUG("[%s] rank=%d dst=%d size=%zu (fp32) src=%lx dst=%lx: first values", + op_name, backend_->get_rank(), dst_rank, size, src_ptr, dst_ptr); for (int i = 0; i < elem_count; i++) { fprintf(stderr, "%.1f ", float_ptr[i]); } + fprintf(stderr, "\n"); } - fprintf(stderr, "\n"); } /** @@ -198,14 +198,14 @@ class rdma_proxy { // No memcpy needed - just RDMA directly from heap! void* local_addr = (void*)src_ptr; - DEBUG_PRINT("[IrisManager] PUT: rank=%d src=%lx dst=%lx size=%zu", - dst_rank, src_ptr, dst_ptr, size); + LOG_DEBUG("PUT: rank=%d src=%lx dst=%lx size=%zu", + dst_rank, src_ptr, dst_ptr, size); debug_print_work_item(item); int ret = backend_->rdma_write(dst_rank, local_addr, dst_ptr, size); if (ret != 0) { - fprintf(stderr, "[IrisManager] RDMA write failed: dst=%d size=%lu\n", dst_rank, size); + LOG_ERROR("RDMA write failed: dst=%d size=%lu", dst_rank, size); } else { // Poll for completion int n = 0; @@ -229,12 +229,12 @@ class rdma_proxy { // GPU will read from heap after completion void* local_addr = (void*)src_ptr; - DEBUG_PRINT("[IrisManager] GET: rank=%d src=%lx dst=%lx size=%zu", - dst_rank, dst_ptr, src_ptr, size); + LOG_DEBUG("GET: rank=%d src=%lx dst=%lx size=%zu", + dst_rank, dst_ptr, src_ptr, size); int ret = backend_->rdma_read(dst_rank, local_addr, dst_ptr, size); if (ret != 0) { - fprintf(stderr, "[IrisManager] RDMA read failed: dst=%d size=%lu\n", dst_rank, size); + LOG_ERROR("RDMA read failed: dst=%d size=%lu", dst_rank, size); } else { // Poll for completion int n = 0; @@ -255,7 +255,7 @@ class rdma_proxy { case rdma::operation_type::FLUSH: { // Flush all pending operations for this rank - DEBUG_PRINT("[IrisManager] FLUSH: rank=%d", dst_rank); + LOG_DEBUG("FLUSH: rank=%d", dst_rank); int total = 0; int n; @@ -269,7 +269,7 @@ class rdma_proxy { } default: - fprintf(stderr, "[IrisManager] Unknown operation type: %d\n", item.header.op_type); + LOG_ERROR("Unknown operation type: %d", item.header.op_type); queue_->pop(); } } diff --git a/iris/experimental/iris_rdma/src/logging.hpp b/iris/experimental/iris_rdma/src/logging.hpp new file mode 100644 index 00000000..80a968f9 --- /dev/null +++ b/iris/experimental/iris_rdma/src/logging.hpp @@ -0,0 +1,89 @@ +#pragma once + +#include +#include +#include + +namespace iris { +namespace rdma { + +// Log levels +enum class log_level { + DEBUG = 0, + INFO = 1, + WARN = 2, + ERROR = 3, + NONE = 4 +}; + +// Global log level (can be set via environment variable) +inline log_level get_log_level() { + static log_level level = []() { + const char* env = std::getenv("IRIS_LOG_LEVEL"); + if (!env) return log_level::INFO; + + if (strcmp(env, "DEBUG") == 0) return log_level::DEBUG; + if (strcmp(env, "INFO") == 0) return log_level::INFO; + if (strcmp(env, "WARN") == 0) return log_level::WARN; + if (strcmp(env, "ERROR") == 0) return log_level::ERROR; + if (strcmp(env, "NONE") == 0) return log_level::NONE; + + return log_level::INFO; + }(); + return level; +} + +// Check if debug data printing is enabled (separate from log level) +inline bool is_debug_data_enabled() { + static bool enabled = (std::getenv("IRIS_DEBUG_DATA") != nullptr); + return enabled; +} + +// Internal logging function +inline void log_message(log_level level, const char* level_str, const char* fmt, ...) { + if (level < get_log_level()) return; + + // Get timestamp + time_t now = time(nullptr); + struct tm* tm_info = localtime(&now); + char time_buf[64]; + strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", tm_info); + + // Print level and timestamp + fprintf(stderr, "[%s] [%s] ", time_buf, level_str); + + // Print message + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); + fflush(stderr); +} + +} // namespace rdma +} // namespace iris + +// Logging macros - easy to replace with real logging library later +#define LOG_DEBUG(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::DEBUG, "DEBUG", fmt, ##__VA_ARGS__) + +#define LOG_INFO(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::INFO, "INFO", fmt, ##__VA_ARGS__) + +#define LOG_WARN(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::WARN, "WARN", fmt, ##__VA_ARGS__) + +#define LOG_ERROR(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::ERROR, "ERROR", fmt, ##__VA_ARGS__) + +// For data debugging (separate from regular logging) +#define LOG_DATA_DEBUG(fmt, ...) \ + do { \ + if (iris::rdma::is_debug_data_enabled()) { \ + fprintf(stderr, "[DEBUG-DATA] " fmt "\n", ##__VA_ARGS__); \ + fflush(stderr); \ + } \ + } while (0) + diff --git a/iris/experimental/iris_rdma/src/network_backend.hpp b/iris/experimental/iris_rdma/src/network_backend.hpp index 03a450ab..268c0aae 100644 --- a/iris/experimental/iris_rdma/src/network_backend.hpp +++ b/iris/experimental/iris_rdma/src/network_backend.hpp @@ -63,14 +63,14 @@ class network_backend { } rank_ = bootstrap_->get_rank(); world_size_ = bootstrap_->get_world_size(); - DEBUG_PRINT("network_backend created: rank=%d, world_size=%d", rank_, world_size_); + LOG_INFO("network_backend created: rank=%d, world_size=%d", rank_, world_size_); } /** * @brief Destructor - cleanup InfiniBand resources */ ~network_backend() { - DEBUG_PRINT("network_backend cleanup started"); + LOG_DEBUG("network_backend cleanup started"); qps_.clear(); @@ -118,7 +118,7 @@ class network_backend { * @brief Initialize the network (setup QPs, transition to RTS) */ void init() { - DEBUG_PRINT("network_backend::init() started"); + LOG_INFO("network_backend::init() started"); autodetect_dv_libs(); open_ib_device(); @@ -129,7 +129,7 @@ class network_backend { modify_qps_rtr_to_rts(); bootstrap_->barrier(); - DEBUG_PRINT("NetworkBackend::init() completed"); + LOG_INFO("network_backend::init() completed"); } /** @@ -138,7 +138,7 @@ class network_backend { * @param size Size in bytes */ void register_memory(void* ptr, size_t size) { - DEBUG_PRINT("Registering memory: ptr=%p, size=%zu", ptr, size); + LOG_INFO("Registering memory: ptr=%p, size=%zu", ptr, size); int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; @@ -186,8 +186,8 @@ class network_backend { } } - DEBUG_PRINT("Memory registered: lkey=%u, rkey=%u, heap_base=%p", - lkey, heap_mr_->rkey, ptr); + LOG_INFO("Memory registered: lkey=%u, rkey=%u, heap_base=%p", + lkey, heap_mr_->rkey, ptr); } /** @@ -349,7 +349,7 @@ class network_backend { int n = ibv_poll_cq(qp->get_ibv_cq(), num_to_poll, wc); if (n < 0) { - DEBUG_PRINT("CQ poll error for rank %d", dst_rank); + LOG_ERROR("CQ poll error for rank %d", dst_rank); return n; } @@ -407,25 +407,25 @@ class network_backend { // Vendor-specific init void autodetect_dv_libs() { - DEBUG_PRINT("Auto-detecting vendor libraries..."); + LOG_DEBUG("Auto-detecting vendor libraries..."); // Try MLX5 if (mlx5_dv_dl_init() == 0) { vendor_ = rdma::nic_vendor::MLX5; - DEBUG_PRINT("Detected MLX5 vendor"); + LOG_INFO("Detected MLX5 vendor"); return; } // Try BNXT if (bnxt_dv_dl_init() == 0) { vendor_ = rdma::nic_vendor::BNXT; - DEBUG_PRINT("Detected BNXT vendor"); + LOG_INFO("Detected BNXT vendor"); return; } // Default to standard verbs vendor_ = rdma::nic_vendor::NONE; - DEBUG_PRINT("Using standard InfiniBand verbs"); + LOG_INFO("Using standard InfiniBand verbs"); } int mlx5_dv_dl_init() { @@ -457,7 +457,7 @@ class network_backend { } void open_ib_device() { - DEBUG_PRINT("Opening InfiniBand device..."); + LOG_INFO("Opening InfiniBand device..."); struct ibv_device** device_list = nullptr; struct ibv_device* device = nullptr; @@ -511,12 +511,12 @@ class network_backend { ibv_free_device_list(device_list); - DEBUG_PRINT("InfiniBand device opened: %s", - ibv_get_device_name(context_->device)); + LOG_INFO("InfiniBand device opened: %s", + ibv_get_device_name(context_->device)); } void create_parent_domain() { - DEBUG_PRINT("Creating parent domain..."); + LOG_DEBUG("Creating parent domain..."); struct ibv_parent_domain_init_attr pattr; memset(&pattr, 0, sizeof(pattr)); @@ -531,7 +531,7 @@ class network_backend { } void select_gid_index() { - DEBUG_PRINT("Selecting GID index..."); + LOG_DEBUG("Selecting GID index..."); const uint8_t local_gid_prefix[2] = {0xFE, 0x80}; int selected_gid_index = -1; @@ -568,11 +568,11 @@ class network_backend { gid_index_ = selected_gid_index; gid_ = selected_gid; - DEBUG_PRINT("Selected GID index: %d", gid_index_); + LOG_DEBUG("Selected GID index: %d", gid_index_); } void create_queues() { - DEBUG_PRINT("Creating queues..."); + LOG_DEBUG("Creating queues..."); int ncqes = 64; // Number of CQ entries int sq_length = 64; // Send queue length @@ -586,11 +586,11 @@ class network_backend { create_cqs(ncqes); create_qps(sq_length); - DEBUG_PRINT("Created %d queue pairs", world_size_); + LOG_INFO("Created %d queue pairs", world_size_); } void create_cqs(int ncqes) { - DEBUG_PRINT("Creating completion queues: ncqes=%d", ncqes); + LOG_DEBUG("Creating completion queues: ncqes=%d", ncqes); struct ibv_cq_init_attr_ex cq_attr; memset(&cq_attr, 0, sizeof(cq_attr)); @@ -616,7 +616,7 @@ class network_backend { } void create_qps(int sq_length) { - DEBUG_PRINT("Creating queue pairs: sq_length=%d", sq_length); + LOG_DEBUG("Creating queue pairs: sq_length=%d", sq_length); struct ibv_qp_init_attr_ex attr; memset(&attr, 0, sizeof(attr)); @@ -641,7 +641,7 @@ class network_backend { } void exchange_qp_dest_info() { - DEBUG_PRINT("Exchanging QP destination info..."); + LOG_DEBUG("Exchanging QP destination info..."); // Fill local dest info for (int i = 0; i < world_size_; i++) { @@ -654,11 +654,11 @@ class network_backend { // All-gather dest info bootstrap_->all_gather(dest_info_.data(), sizeof(rdma::qp_dest_info_t)); - DEBUG_PRINT("QP destination info exchanged"); + LOG_DEBUG("QP destination info exchanged"); } void modify_qps_reset_to_init() { - DEBUG_PRINT("Transitioning QPs: RESET -> INIT"); + LOG_DEBUG("Transitioning QPs: RESET -> INIT"); struct ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); @@ -679,7 +679,7 @@ class network_backend { } void modify_qps_init_to_rtr() { - DEBUG_PRINT("Transitioning QPs: INIT -> RTR"); + LOG_DEBUG("Transitioning QPs: INIT -> RTR"); struct ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); @@ -718,7 +718,7 @@ class network_backend { } void modify_qps_rtr_to_rts() { - DEBUG_PRINT("Transitioning QPs: RTR -> RTS"); + LOG_DEBUG("Transitioning QPs: RTR -> RTS"); struct ibv_qp_attr attr; memset(&attr, 0, sizeof(attr)); diff --git a/iris/experimental/iris_rdma/src/queue_pair.hpp b/iris/experimental/iris_rdma/src/queue_pair.hpp index ecd90c53..e085729a 100644 --- a/iris/experimental/iris_rdma/src/queue_pair.hpp +++ b/iris/experimental/iris_rdma/src/queue_pair.hpp @@ -35,14 +35,14 @@ class queue_pair { CHECK_NNULL(qp_, "QueuePair: ibv_qp"); CHECK_NNULL(cq_, "QueuePair: ibv_cq"); qp_num_ = qp_->qp_num; - DEBUG_PRINT("QueuePair created: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + LOG_DEBUG("queue_pair created: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); } /** * @brief Destructor */ inline ~queue_pair() { - DEBUG_PRINT("queue_pair destroyed: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + LOG_DEBUG("queue_pair destroyed: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); } /** diff --git a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp index 73aca004..87bb8019 100644 --- a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp +++ b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp @@ -32,7 +32,7 @@ class torch_bootstrap { } rank_ = process_group_->getRank(); world_size_ = process_group_->getSize(); - DEBUG_PRINT("torch_bootstrap initialized: rank=%d, world_size=%d", rank_, world_size_); + LOG_INFO("torch_bootstrap initialized: rank=%d, world_size=%d", rank_, world_size_); } /** From 211856ca1052ad2d67323379bf790470ba3f91f0 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Thu, 30 Oct 2025 22:25:37 -0500 Subject: [PATCH 15/16] Add missing logger --- iris/experimental/iris_rdma/src/iris_manager.hpp | 4 ++-- .../iris_rdma/src/network_backend.hpp | 16 ++++++++-------- .../iris_rdma/src/torch_bootstrap.hpp | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp index 4bb24c44..76a4835c 100644 --- a/iris/experimental/iris_rdma/src/iris_manager.hpp +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -215,7 +215,7 @@ class rdma_proxy { std::this_thread::sleep_for(std::chrono::microseconds(10)); } if (n <= 0) { - DEBUG_PRINT("[IrisManager] Warning: PUT completion not polled (may be OK if async)"); + LOG_DEBUG("Warning: PUT completion not polled (may be OK if async)"); } } @@ -244,7 +244,7 @@ class rdma_proxy { std::this_thread::sleep_for(std::chrono::microseconds(10)); } if (n <= 0) { - DEBUG_PRINT("[IrisManager] Warning: GET completion not polled (may be OK if async)"); + LOG_DEBUG("Warning: GET completion not polled (may be OK if async)"); } } diff --git a/iris/experimental/iris_rdma/src/network_backend.hpp b/iris/experimental/iris_rdma/src/network_backend.hpp index 268c0aae..46d84bb4 100644 --- a/iris/experimental/iris_rdma/src/network_backend.hpp +++ b/iris/experimental/iris_rdma/src/network_backend.hpp @@ -111,7 +111,7 @@ class network_backend { bnxtdv_handle_ = nullptr; } - DEBUG_PRINT("NetworkBackend cleanup completed"); + LOG_DEBUG("NetworkBackend cleanup completed"); } /** @@ -286,8 +286,8 @@ class network_backend { struct ibv_send_wr* bad_wr; int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); - DEBUG_PRINT("RDMA Write to rank %d: local=%p remote=%lx size=%zu ret=%d", - dst_rank, local_addr, remote_addr, size, ret); + LOG_DEBUG("RDMA Write to rank %d: local=%p remote=%lx size=%zu ret=%d", + dst_rank, local_addr, remote_addr, size, ret); return ret; } @@ -326,8 +326,8 @@ class network_backend { struct ibv_send_wr* bad_wr; int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); - DEBUG_PRINT("RDMA Read from rank %d: local=%p remote=%lx size=%zu ret=%d", - dst_rank, local_addr, remote_addr, size, ret); + LOG_DEBUG("RDMA Read from rank %d: local=%p remote=%lx size=%zu ret=%d", + dst_rank, local_addr, remote_addr, size, ret); return ret; } @@ -362,7 +362,7 @@ class network_backend { } } - DEBUG_PRINT("Polled %d completions from rank %d", n, dst_rank); + LOG_DEBUG("Polled %d completions from rank %d", n, dst_rank); return n; } @@ -435,7 +435,7 @@ class network_backend { } if (!mlx5dv_handle_) { - DEBUG_PRINT("Could not open libmlx5.so"); + LOG_DEBUG("Could not open libmlx5.so"); return -1; } @@ -449,7 +449,7 @@ class network_backend { } if (!bnxtdv_handle_) { - DEBUG_PRINT("Could not open libbnxt_re.so"); + LOG_DEBUG("Could not open libbnxt_re.so"); return -1; } diff --git a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp index 87bb8019..081336b3 100644 --- a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp +++ b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp @@ -77,7 +77,7 @@ class torch_bootstrap { std::memcpy(static_cast(allData) + i * size, cpu_output.data_ptr(), size); } - DEBUG_PRINT("AllGather completed: %d bytes per rank", size); + LOG_DEBUG("AllGather completed: %d bytes per rank", size); } /** @@ -88,7 +88,7 @@ class torch_bootstrap { inline void barrier() { auto work = process_group_->barrier(); work->wait(); - DEBUG_PRINT("Barrier completed"); + LOG_DEBUG("Barrier completed"); } /** From 955dc3dd1e81a0fedc391aef13212ddee9ac9756 Mon Sep 17 00:00:00 2001 From: Muhammad Awad Date: Fri, 31 Oct 2025 15:02:43 -0500 Subject: [PATCH 16/16] Write QPs --- examples/23_rdma_consumer_pull/README.md | 88 ++++++++ .../rdma_consumer_pull.py | 198 ++++++++++++++++ examples/24_rdma_atomic_add/README.md | 127 +++++++++++ .../24_rdma_atomic_add/rdma_atomic_add.py | 145 ++++++++++++ iris/experimental/iris_rdma.py | 174 +++++++++++++++ .../iris_rdma/src/iris_manager.hpp | 199 ++++++++++++++++- iris/experimental/iris_rdma/src/logging.hpp | 36 +++ .../iris_rdma/src/network_backend.hpp | 211 +++++++++++++----- iris/experimental/iris_rdma/src/queue.hpp | 15 +- .../experimental/iris_rdma/src/queue_pair.hpp | 11 + rebuild.sh | 4 + run.sh | 7 + 12 files changed, 1144 insertions(+), 71 deletions(-) create mode 100644 examples/23_rdma_consumer_pull/README.md create mode 100644 examples/23_rdma_consumer_pull/rdma_consumer_pull.py create mode 100644 examples/24_rdma_atomic_add/README.md create mode 100644 examples/24_rdma_atomic_add/rdma_atomic_add.py create mode 100755 rebuild.sh create mode 100755 run.sh diff --git a/examples/23_rdma_consumer_pull/README.md b/examples/23_rdma_consumer_pull/README.md new file mode 100644 index 00000000..288b5014 --- /dev/null +++ b/examples/23_rdma_consumer_pull/README.md @@ -0,0 +1,88 @@ +# 23. RDMA Consumer Pull (GET) + +Consumer-pull pattern using InfiniBand RDMA GET operations for multi-node communication. + +## Overview + +This example demonstrates: +- Rank 1 (Server) prepares data in its heap +- Rank 0 (Client) uses RDMA GET to pull data from Rank 1 +- Triton kernel verifies pulled data on Rank 0 + +**Key Difference from Example 22:** +- **Example 22 (PUT)**: Sender initiates - Rank 0 pushes data to Rank 1 +- **Example 23 (GET)**: Receiver initiates - Rank 0 pulls data from Rank 1 + +## Requirements + +- InfiniBand network adapter +- libibverbs-dev installed +- Iris built with RDMA support + +## Architecture + +``` +Rank 1 (Server) Rank 0 (Client) +─────────────── ─────────────── +Data in heap + ↓ +CPU buffer + RDMA GET ←──────────┐ + │ +CPU buffer ←───────────────────────────────────────┘ + ↓ +CPU → GPU + ↓ +verify_kernel() + ↓ verifies +✓ Success +``` + +## Usage + +### Single Node (2 GPUs) +```bash +torchrun --nproc_per_node=2 rdma_consumer_pull.py +``` + +### Multi-Node (2 Nodes, 1 GPU each) +```bash +# Node 0 (Client - pulls data) +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=0 \ + --master_addr= --master_port=29500 \ + rdma_consumer_pull.py + +# Node 1 (Server - provides data) +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=1 \ + --master_addr= --master_port=29500 \ + rdma_consumer_pull.py +``` + +## Expected Output + +``` +[Rank 0/2] Initialized on cuda:0 +[Rank 1/2] Initialized on cuda:1 +[Rank 1] Server: Preparing Data +[Rank 1] Data ready, first 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 0] Client: Pulling Data via RDMA GET +[Rank 0] RDMA GET operations enqueued to queue +[Rank 0] Barrier complete, all RDMA operations finished +[Rank 0] Received data first 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 0] Verified: 4091/4091 +============================================================ +[Rank 0] SUCCESS! Data pulled correctly via RDMA GET! +``` + +## RDMA GET vs PUT + +### When to use GET: +- **Consumer-initiated**: Receiver decides when to pull data +- **Pull-based flow control**: Consumer controls rate +- **Useful for**: Demand-driven workloads, load balancing + +### When to use PUT: +- **Producer-initiated**: Sender decides when to push data +- **Push-based flow control**: Producer controls rate +- **Useful for**: Pipeline parallelism, streaming workloads + diff --git a/examples/23_rdma_consumer_pull/rdma_consumer_pull.py b/examples/23_rdma_consumer_pull/rdma_consumer_pull.py new file mode 100644 index 00000000..9c1cf3a1 --- /dev/null +++ b/examples/23_rdma_consumer_pull/rdma_consumer_pull.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import os +import sys +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import time + +import iris.experimental.iris_rdma as iris_rdma + + +@triton.jit +def consumer_get_kernel( + local_ptr, + remote_ptr, + n_elements, + src_rank: tl.constexpr, + device_ctx, + BLOCK_SIZE: tl.constexpr, +): + """ + Consumer kernel that enqueues RDMA get operations to pull data. + Uses symmetric heap model: remote_ptr points to same offset in remote heap. + After RDMA get completes, data will be available at local_ptr. + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Local and remote pointers (same offset in symmetric heap) + local_ptrs = local_ptr + offsets + remote_ptrs = remote_ptr + offsets + + # Enqueue RDMA GET operation: pull from remote to local + iris_rdma.get(local_ptrs, remote_ptrs, src_rank, device_ctx, mask) + + +@triton.jit +def verify_kernel( + input_ptr, + result_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """ + Verification kernel that checks received data. + Expected pattern: ascending numbers 0, 1, 2, ..., n_elements-1 + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load received data + data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + + # Check if it matches expected pattern (0, 1, 2, 3, ...) + expected = offsets.to(data.dtype) + is_correct = (data == expected).to(tl.int32) + + tl.store(result_ptr + offsets, is_correct, mask=mask) + + +def main(): + + dtype = torch.bfloat16 + + # Initialize distributed + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + device_id = torch.device(f"cuda:{local_rank}") + + dist.init_process_group( + backend='nccl', + device_id=device_id + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + + if world_size < 2: + print("This example requires at least 2 ranks") + sys.exit(1) + + torch.cuda.set_device(local_rank) + device = f'cuda:{local_rank}' + + print(f"[Rank {rank}/{world_size}] Initialized on {device}") + + # Create Iris RDMA context with queue + heap_size = 1024 * 1024 * 8 # 8MB + queue_size = 512 + ctx = iris_rdma.iris(heap_size=heap_size, queue_size=queue_size) + + print(f"[Rank {rank}] Iris RDMA initialized") + print(f"[Rank {rank}] - Heap base: {ctx.get_heap_base():#x}") + print(f"[Rank {rank}] - Queue ptr: {ctx.get_queue_ptr():#x}") + + # Get device context for Triton kernels + device_ctx = ctx.get_device_context() + + # Allocate buffers in symmetric heap + n_elements = 4091 + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + # Allocate on the symmetric heap + local_buffer = ctx.zeros(n_elements, dtype=dtype) + + ctx.barrier() + + # ============================================================ + # SERVER (Rank 1): Prepare data for RDMA get + # ============================================================ + if rank == 1: + print(f"\n[Rank 1] === Server: Preparing Data ===") + + # Fill buffer with data using PyTorch + print(f"[Rank 1] Filling buffer with data...") + local_buffer.copy_(torch.arange(n_elements, dtype=dtype, device=device)) + torch.cuda.synchronize() + print(f"[Rank 1] Data ready, first 10: {local_buffer[:10].tolist()}") + print(f"[Rank 1] Waiting for client to pull data...") + + # ============================================================ + # CLIENT (Rank 0): Pull data using RDMA get + # ============================================================ + if rank == 0: + print(f"\n[Rank 0] === Client: Pulling Data via RDMA GET ===") + src_rank = 1 + + # Launch RDMA GET enqueue kernel + print(f"[Rank 0] Launching RDMA GET kernel to pull from Rank {src_rank}...") + consumer_get_kernel[grid]( + local_buffer, # local destination + local_buffer, # remote source (same offset in symmetric heap) + n_elements, + src_rank=src_rank, + device_ctx=device_ctx, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Wait for GPU to finish enqueueing + torch.cuda.synchronize() + print(f"[Rank 0] RDMA GET operations enqueued to queue") + + ctx.barrier() + print(f"[Rank {rank}] Barrier complete, all RDMA operations finished") + + # ============================================================ + # CLIENT (Rank 0): Verify pulled data + # ============================================================ + if rank == 0: + print(f"\n[Rank 0] === Verifying Pulled Data ===") + + # Show received data + print(f"[Rank 0] Received data first 10: {local_buffer[:10].tolist()}") + + # Verify data (use int32 for result buffer - stores 0 or 1 for correctness) + result_buffer = torch.zeros(n_elements, dtype=torch.int32, device=device) + + verify_kernel[grid]( + local_buffer, + result_buffer, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + result_cpu = result_buffer.cpu() + num_correct = result_cpu.sum().item() + num_total = n_elements + + print(f"[Rank 0] Verified: {int(num_correct)}/{num_total}") + + if num_correct == num_total: + print(f"\n" + "="*60) + print(f"[Rank 0] SUCCESS! Data pulled correctly via RDMA GET!") + else: + print(f"[Rank 0] FAILED - Data mismatch!") + first_wrong_idx = (result_cpu == 0).nonzero(as_tuple=True)[0] + if len(first_wrong_idx) > 0: + idx = first_wrong_idx[0].item() + print(f"[Rank 0] First wrong at index {idx}") + print(f"[Rank 0] Expected: {idx}") + print(f"[Rank 0] Got: {local_buffer[idx].item()}") + sys.exit(1) + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + diff --git a/examples/24_rdma_atomic_add/README.md b/examples/24_rdma_atomic_add/README.md new file mode 100644 index 00000000..75d0cbf0 --- /dev/null +++ b/examples/24_rdma_atomic_add/README.md @@ -0,0 +1,127 @@ +# RDMA Atomic Add Example + +This example demonstrates RDMA atomic fetch-and-add operations using Iris RDMA. + +## Overview + +In this example: +- **Rank 0** maintains a shared counter in its symmetric heap +- **All ranks** (0 through N-1) atomically increment rank 0's counter +- Each rank adds its own rank number + 1 (i.e., rank 0 adds 1, rank 1 adds 2, etc.) +- The atomic operation returns the old value before incrementing +- Rank 0 verifies the final sum + +## Key Concepts + +### Atomic Fetch-and-Add +```python +iris_rdma.atomic_add( + result_ptr, # Local buffer to store old value + counter_ptr, # Remote counter location (symmetric heap) + increment, # Value to add + dst_rank, # Which rank owns the counter + device_ctx, # Device context + mask, # Triton mask +) +``` + +- **Atomic**: Operation is indivisible - no race conditions +- **Fetch**: Returns the original value before the add +- **Symmetric Heap**: All ranks use same offset, automatically translated + +### Expected Result + +For N ranks, each rank i adds (i+1): +``` +Final counter = 1 + 2 + 3 + ... + N = N × (N+1) / 2 +``` + +For 2 ranks: 1 + 2 = 3 +For 4 ranks: 1 + 2 + 3 + 4 = 10 +For 8 ranks: 1 + 2 + 3 + ... + 8 = 36 + +## Running the Example + +### With 2 ranks: +```bash +torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py +``` + +### With 4 ranks: +```bash +torchrun --nproc_per_node=4 examples/24_rdma_atomic_add/rdma_atomic_add.py +``` + +### With debug logging: +```bash +IRIS_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py +``` + +## Expected Output + +``` +[Rank 0/2] Initialized on cuda:0 +[Rank 1/2] Initialized on cuda:1 +[Rank 0] Iris RDMA initialized +[Rank 1] Iris RDMA initialized + +[Rank 0] === Testing Atomic Add === +[Rank 0] Initial counter value: 0 +[Rank 0] Waiting for other ranks to increment... + +[Rank 0] Atomically adding 1 to rank 0's counter... +[Rank 1] Atomically adding 2 to rank 0's counter... +[Rank 0] Atomic add completed. Old value was: 0 +[Rank 1] Atomic add completed. Old value was: 1 + +[Rank 0] === Verification === +[Rank 0] Final counter value: 3 +[Rank 0] Expected value: 3 +[Rank 0] Each rank added: [1, 2] + +============================================================ +[Rank 0] SUCCESS! Atomic operations worked correctly! +============================================================ +``` + +## How It Works + +1. **Initialization**: All ranks initialize Iris RDMA with symmetric heaps +2. **Buffer Allocation**: Each rank allocates counter/result buffers at same offset +3. **Atomic Operations**: + - Ranks launch Triton kernels that call `iris_rdma.atomic_add()` + - Triton kernel enqueues atomic operation to device queue + - CPU proxy thread dequeues and executes RDMA atomic via InfiniBand + - Original value is returned to result buffer +4. **Verification**: Rank 0 checks that sum equals expected value + +## Key Features Demonstrated + +- ✅ **RDMA Atomics**: Hardware-level atomic operations over InfiniBand +- ✅ **Symmetric Heap**: Automatic address translation between ranks +- ✅ **Fetch-and-Add**: Returns old value atomically +- ✅ **GPU-initiated**: Triton kernel directly initiates RDMA operations +- ✅ **Zero-copy**: No intermediate buffers or CPU involvement for data path + +## Notes + +- Atomic operations require **64-bit integers** (`torch.int64` or `torch.uint64`) +- 32-bit atomics are also supported by changing the size parameter +- Operations are **synchronous** - kernel waits for completion before returning +- All ranks must allocate buffers at the **same symmetric heap offset** + +## Troubleshooting + +**Counter value is wrong:** +- Check that all ranks successfully performed atomic operations +- Verify InfiniBand connection is working +- Enable debug logging to see RDMA operations + +**Atomics not supported error:** +- Ensure your InfiniBand HCA supports atomic operations +- Most modern Mellanox/NVIDIA and Broadcom NICs support this + +**Hang on barrier:** +- Check that all ranks are running +- Verify NCCL is properly configured + diff --git a/examples/24_rdma_atomic_add/rdma_atomic_add.py b/examples/24_rdma_atomic_add/rdma_atomic_add.py new file mode 100644 index 00000000..824a678e --- /dev/null +++ b/examples/24_rdma_atomic_add/rdma_atomic_add.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +RDMA Atomic Add Example + +Demonstrates RDMA atomic fetch-and-add operations between ranks. +Each rank atomically increments a counter on rank 0. +""" + +import os +import sys +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +from iris.experimental import iris_rdma + + +@triton.jit +def atomic_add_kernel( + counter_ptr, + result_ptr, + increment, + dst_rank: tl.constexpr, + device_ctx, +): + """ + Each thread atomically adds its increment to the remote counter. + Returns the old value before increment. + """ + pid = tl.program_id(0) + + # Only first thread does the atomic add + if pid == 0: + # Create a mask for single element operation + mask = tl.full([1], 1, dtype=tl.int1) + + # Atomic add: increment counter on dst_rank, get old value + iris_rdma.atomic_add( + result_ptr, # Where to store old value + counter_ptr, # Remote counter location (symmetric heap) + increment, # Value to add + dst_rank, # Which rank has the counter + device_ctx, + mask, + ) + + +def main(): + dtype = torch.int64 # Atomics require int64/uint64 + + # Initialize distributed + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + device_id = torch.device(f"cuda:{local_rank}") + + dist.init_process_group( + backend='nccl', + device_id=device_id + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") + + print(f"[Rank {rank}/{world_size}] Initialized on {device}") + + # Initialize RDMA context + ctx = iris_rdma.IrisRDMA() + print(f"[Rank {rank}] Iris RDMA initialized") + print(f"[Rank {rank}] - Heap base: {ctx.get_heap_base():#x}") + print(f"[Rank {rank}] - Queue ptr: {ctx.get_queue_ptr():#x}") + + # Get device context for Triton kernels + device_ctx = ctx.get_device_context() + + # Allocate counter and result buffer in symmetric heap + counter = ctx.zeros(1, dtype=dtype) # Shared counter + result = ctx.zeros(1, dtype=dtype) # Store old value + + ctx.barrier() + + # ============================================================ + # Rank 0 atomically increments rank 1's counter + # ============================================================ + print(f"\n[Rank {rank}] === Testing Atomic Add ===") + + if rank == 1: + print(f"[Rank 1] Initial counter value: {counter[0].item()}") + print(f"[Rank 1] Waiting for rank 0 to increment...") + + ctx.barrier() + + # Only rank 0 performs the atomic operation (to avoid local atomic on rank 1) + if rank == 0: + increment = 42 # Arbitrary test value + target_rank = 1 + print(f"[Rank 0] Atomically adding {increment} to rank {target_rank}'s counter...") + + # Launch atomic add kernel + grid = (1,) # Single thread + atomic_add_kernel[grid]( + counter, # Counter location (same offset on all ranks) + result, # Where to store old value + increment, # Value to add + dst_rank=target_rank, + device_ctx=device_ctx, + ) + + # Synchronize GPU + torch.cuda.synchronize() + + # Read the old value returned by atomic add + old_value = result.cpu()[0].item() + print(f"[Rank 0] Atomic add completed. Old value was: {old_value}") + + ctx.barrier() + + # ============================================================ + # Rank 1: Verify final counter value + # ============================================================ + if rank == 1: + print(f"\n[Rank 1] === Verification ===") + final_value = counter.cpu()[0].item() + expected = 42 # Only rank 0 added 42 + + print(f"[Rank 1] Final counter value: {final_value}") + print(f"[Rank 1] Expected value: {expected}") + + if final_value == expected: + print("\n" + "="*60) + print("[Rank 1] SUCCESS! RDMA atomic add worked correctly!") + print("="*60) + else: + print(f"[Rank 1] FAILED - Counter value mismatch!") + print(f"[Rank 1] Expected: {expected}") + print(f"[Rank 1] Got: {final_value}") + sys.exit(1) + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py index e82224f2..45b2adb9 100644 --- a/iris/experimental/iris_rdma.py +++ b/iris/experimental/iris_rdma.py @@ -639,10 +639,184 @@ def get(dst_ptr, src_ptr, from_rank: tl.constexpr, device_ctx, mask): # Data is now ready at dst_ptr (CPU has written it there via RDMA) +@triton.jit +def atomic_add(result_ptr, dst_ptr, add_value, dst_rank: tl.constexpr, device_ctx, mask): + """ + RDMA atomic fetch-and-add operation from Triton kernel. + + Atomically adds a value to remote memory and returns the original value. + Uses symmetric heap model: dst_ptr is in current rank's address space, + and will be automatically translated to remote rank's address space. + + Args: + result_ptr: Local pointer where the original value will be stored + dst_ptr: Destination pointer in CURRENT rank's address space (symmetric heap) + add_value: Value to add (must be scalar uint64 or int64) + dst_rank: Destination rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Note: Only supports 8-byte (uint64/int64) atomic operations. + The result_ptr will contain the original value before the add. + """ + # Extract context fields + my_rank = tl.load(device_ctx + 0) + queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate dst_ptr from current rank's address space to remote rank's + translated_dst_ptr = _translate(dst_ptr, my_rank, dst_rank, heap_bases) + + # Enqueue ATOMIC_ADD operation (op_code=4) + # For ATOMIC_ADD: result_ptr is local result buffer, translated_dst_ptr is remote target + queue_pos = _enqueue_atomic_op(result_ptr, translated_dst_ptr, dst_rank, 4, + add_value, 0, queue_ptr, mask) + + # Wait for CPU to complete the atomic operation + _wait_for_completion(queue_ptr, queue_pos) + + # Result is now ready at result_ptr (original value before add) + + +@triton.jit +def atomic_cas(result_ptr, dst_ptr, compare_value, swap_value, dst_rank: tl.constexpr, device_ctx, mask): + """ + RDMA atomic compare-and-swap operation from Triton kernel. + + Atomically compares remote memory with expected value and swaps if equal. + Returns the original value. Uses symmetric heap model. + + Args: + result_ptr: Local pointer where the original value will be stored + dst_ptr: Destination pointer in CURRENT rank's address space (symmetric heap) + compare_value: Expected value (must be scalar uint64 or int64) + swap_value: New value if comparison succeeds (must be scalar uint64 or int64) + dst_rank: Destination rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Note: Only supports 8-byte (uint64/int64) atomic operations. + The result_ptr will contain the original value at the remote location. + If result == compare_value, the swap succeeded. + """ + # Extract context fields + my_rank = tl.load(device_ctx + 0) + queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate dst_ptr from current rank's address space to remote rank's + translated_dst_ptr = _translate(dst_ptr, my_rank, dst_rank, heap_bases) + + # Enqueue ATOMIC_CAS operation (op_code=6) + queue_pos = _enqueue_atomic_op(result_ptr, translated_dst_ptr, dst_rank, 6, + swap_value, compare_value, queue_ptr, mask) + + # Wait for CPU to complete the atomic operation + _wait_for_completion(queue_ptr, queue_pos) + + # Result is now ready at result_ptr (original value from remote) + + +@triton.jit +def _enqueue_atomic_op(result_ptr, dst_ptr, to_rank: tl.constexpr, op_code: tl.constexpr, + operand, compare, queue_ptr, mask): + """ + Internal: Enqueue an atomic RDMA operation to the queue. + + Args: + result_ptr: Local pointer for result + dst_ptr: Destination pointer on remote rank (already translated) + to_rank: Target rank ID + op_code: Operation type (4=ATOMIC_ADD, 6=ATOMIC_CAS) + operand: Operand value (add_value or swap_value) + compare: Compare value (0 for ADD, compare_value for CAS) + queue_ptr: Queue pointer from device context + mask: Triton mask for valid elements + """ + state_ptr = queue_ptr.to(tl.pointer_type(tl.uint64)) + + # Load QueueState fields + items_ptr = tl.load(state_ptr + 0) + head_ptr = tl.load(state_ptr + 1) + tail_ptr = tl.load(state_ptr + 2) + + # Load size (at offset 32 bytes = 4 * uint64) + size_ptr = queue_ptr.to(tl.pointer_type(tl.int32)) + size = tl.load(size_ptr + 8) + + # Atomic increment head to reserve slot + head_ptr_typed = head_ptr.to(tl.pointer_type(tl.uint64)) + prev_head = tl.atomic_add(head_ptr_typed, 1, sem='relaxed', scope='sys') + + # Wait for slot to be free + size_u64 = size.to(tl.uint64) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + while prev_head >= size_u64 + current_tail: + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + # Calculate slot position + slot_idx = prev_head % size_u64 + + # WorkItem structure (48 bytes total): + # Header (32 bytes due to alignas(16)): + # offset 0: uint64_t dst_ptr + # offset 8: uint64_t src_ptr (result_ptr for atomics) + # offset 16: uint32_t size_bytes (WRITE LAST as ready flag) + # offset 20: uint16_t rank + # offset 22: uint8_t op_type + # offset 23: uint8_t reserved + # offset 24-31: padding (alignas(16) pads header to 32 bytes) + # Atomic fields (16 bytes): + # offset 32: uint64_t atomic_operand + # offset 40: uint64_t atomic_compare + WORK_ITEM_SIZE_BYTES = 48 # Header (32 with padding) + atomic fields (16) + + slot_offset_bytes = slot_idx * WORK_ITEM_SIZE_BYTES + + # Get pointer to this work item + items_ptr_u64 = items_ptr.to(tl.pointer_type(tl.uint64)) + slot_ptr_u64 = items_ptr_u64 + (slot_offset_bytes // 8).to(tl.int32) + + # Cast pointers to uint64 + dst_ptr_val = tl.cast(dst_ptr, tl.uint64) + result_ptr_val = tl.cast(result_ptr, tl.uint64) + operand_u64 = tl.cast(operand, tl.uint64) + compare_u64 = tl.cast(compare, tl.uint64) + + # Write WorkItem fields (except size which is written last as ready flag) + # Offset 0: dst_ptr (remote target) + tl.store(slot_ptr_u64 + 0, dst_ptr_val) + + # Offset 8: src_ptr (result buffer) + tl.store(slot_ptr_u64 + 1, result_ptr_val) + + # Offset 32: atomic_operand (offset 32 bytes = 4 * 8 bytes) + tl.store(slot_ptr_u64 + 4, operand_u64) + + # Offset 40: atomic_compare (offset 40 bytes = 5 * 8 bytes) + tl.store(slot_ptr_u64 + 5, compare_u64) + + # Offset 20 (bytes): Pack rank (16 bits) + op_type (8 bits) into 32 bits + # Same as regular RDMA operations + slot_ptr_u32 = slot_ptr_u64.to(tl.pointer_type(tl.uint32)) + metadata = (to_rank & 0xFFFF) | ((op_code & 0xFF) << 16) + tl.store(slot_ptr_u32 + 5, metadata) # offset 20 bytes = 5 * 4 bytes + + # Offset 16 (bytes) / 4 (uint32): size_bytes - WRITE LAST as ready flag + # For atomics, size is always 8 bytes + tl.store(slot_ptr_u32 + 4, tl.cast(8, tl.uint32)) # offset 16 bytes = 4 * 4 bytes + + return prev_head + + __all__ = [ "IrisRDMA", "iris", "put", "get", + "atomic_add", + "atomic_cas", ] diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp index 76a4835c..ac6c0d8d 100644 --- a/iris/experimental/iris_rdma/src/iris_manager.hpp +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -15,11 +15,28 @@ #include #include +#include #include "network_backend.hpp" #include "queue.hpp" namespace iris { +/** + * @brief Get maximum number of polling attempts from environment variable + * @return Max attempts (default 100, configurable via IRIS_RDMA_POLL_MAX_ATTEMPTS) + */ +inline int get_max_poll_attempts() { + static int max_attempts = []() { + const char* env = std::getenv("IRIS_RDMA_POLL_MAX_ATTEMPTS"); + if (env) { + int val = std::atoi(env); + if (val > 0) return val; + } + return 100; // Default + }(); + return max_attempts; +} + /** * @brief Complete Iris RDMA Proxy * @@ -180,6 +197,53 @@ class rdma_proxy { } } + /** + * @brief Convert operation type to string + */ + const char* op_type_to_string(uint8_t op_type) { + switch (static_cast(op_type)) { + case rdma::operation_type::NOP: return "NOP"; + case rdma::operation_type::PUT: return "PUT"; + case rdma::operation_type::GET: return "GET"; + case rdma::operation_type::FLUSH: return "FLUSH"; + case rdma::operation_type::ATOMIC_ADD: return "ATOMIC_ADD"; + case rdma::operation_type::ATOMIC_EXCH: return "ATOMIC_EXCH"; + case rdma::operation_type::ATOMIC_CAS: return "ATOMIC_CAS"; + default: return "UNKNOWN"; + } + } + + /** + * @brief Dump raw work item bytes for debugging + */ + void dump_work_item_raw(const rdma::work_item_t& item) { + if (!iris::rdma::is_debug_data_enabled()) return; + + const uint8_t* bytes = reinterpret_cast(&item); + fprintf(stderr, "[DEBUG-DATA] Raw WorkItem (48 bytes):\n"); + fprintf(stderr, "[DEBUG-DATA] Header (32 bytes with alignas(16) padding):\n"); + fprintf(stderr, "[DEBUG-DATA] [0-7] dst_ptr: 0x%016lx\n", item.header.dst_ptr); + fprintf(stderr, "[DEBUG-DATA] [8-15] src_ptr: 0x%016lx\n", item.header.src_ptr); + fprintf(stderr, "[DEBUG-DATA] [16-19] size_bytes: %u\n", item.header.size_bytes); + fprintf(stderr, "[DEBUG-DATA] [20-21] rank: %u\n", item.header.rank); + fprintf(stderr, "[DEBUG-DATA] [22] op_type: %u (%s)\n", + item.header.op_type, op_type_to_string(item.header.op_type)); + fprintf(stderr, "[DEBUG-DATA] [23] reserved: %u\n", item.header.reserved); + fprintf(stderr, "[DEBUG-DATA] [24-31] padding (alignas)\n"); + fprintf(stderr, "[DEBUG-DATA] Atomic fields (16 bytes):\n"); + fprintf(stderr, "[DEBUG-DATA] [32-39] operand: 0x%016lx (%lu)\n", + item.atomic_operand, item.atomic_operand); + fprintf(stderr, "[DEBUG-DATA] [40-47] compare: 0x%016lx (%lu)\n", + item.atomic_compare, item.atomic_compare); + fprintf(stderr, "[DEBUG-DATA] Raw bytes: "); + for (int i = 0; i < 48; i++) { + fprintf(stderr, "%02x ", bytes[i]); + if ((i + 1) % 8 == 0) fprintf(stderr, " "); + } + fprintf(stderr, "\n"); + fflush(stderr); + } + /** * @brief Process a single work item from the queue */ @@ -187,6 +251,13 @@ class rdma_proxy { auto op_type = static_cast(item.header.op_type); int dst_rank = item.header.rank; + // Dump raw packet for atomic operations + if (op_type == rdma::operation_type::ATOMIC_ADD || + op_type == rdma::operation_type::ATOMIC_EXCH || + op_type == rdma::operation_type::ATOMIC_CAS) { + dump_work_item_raw(item); + } + // Get addresses from queue metadata uint64_t src_ptr = item.header.src_ptr; // Pointer/offset in registered heap uint64_t dst_ptr = item.header.dst_ptr; // Remote destination @@ -209,7 +280,8 @@ class rdma_proxy { } else { // Poll for completion int n = 0; - for (int attempt = 0; attempt < 100; attempt++) { + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { n = backend_->poll_cq(dst_rank, 1); if (n > 0) break; std::this_thread::sleep_for(std::chrono::microseconds(10)); @@ -225,20 +297,24 @@ class rdma_proxy { } case rdma::operation_type::GET: { - // RDMA Read: Read from remote directly into registered heap at src_ptr - // GPU will read from heap after completion - void* local_addr = (void*)src_ptr; + // RDMA Read: Read from remote into local + // NOTE: WorkItem field naming is confusing for GET! + // WorkItem.dst_ptr contains REMOTE source (translated by Triton kernel) + // WorkItem.src_ptr contains LOCAL destination + void* local_addr = (void*)src_ptr; // src_ptr field has local dest + uint64_t remote_addr = dst_ptr; // dst_ptr field has remote source - LOG_DEBUG("GET: rank=%d src=%lx dst=%lx size=%zu", - dst_rank, dst_ptr, src_ptr, size); + LOG_DEBUG("GET: rank=%d remote_src=%lx local_dst=%lx size=%zu", + dst_rank, remote_addr, local_addr, size); - int ret = backend_->rdma_read(dst_rank, local_addr, dst_ptr, size); + int ret = backend_->rdma_read(dst_rank, local_addr, remote_addr, size); if (ret != 0) { LOG_ERROR("RDMA read failed: dst=%d size=%lu", dst_rank, size); } else { // Poll for completion int n = 0; - for (int attempt = 0; attempt < 100; attempt++) { + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { n = backend_->poll_cq(dst_rank, 1); if (n > 0) break; std::this_thread::sleep_for(std::chrono::microseconds(10)); @@ -268,6 +344,113 @@ class rdma_proxy { break; } + case rdma::operation_type::ATOMIC_ADD: { + // Atomic add: fetch-and-add operation + // src_ptr = local result buffer, dst_ptr = remote target, atomic_operand = value to add + void* result_addr = (void*)src_ptr; + uint64_t operand = item.atomic_operand; + + LOG_DEBUG("ATOMIC_ADD: rank=%d dst=%lx operand=%lu result_buf=%lx size=%zu", + dst_rank, dst_ptr, operand, src_ptr, size); + + // Local atomics should be handled directly by the GPU kernel, not offloaded to CPU + if (dst_rank == backend_->get_rank()) { + LOG_ERROR("ERROR: Local atomic operation detected (rank %d -> rank %d). " + "Local atomics should be handled directly in the Triton kernel, " + "not offloaded through the RDMA queue!", + backend_->get_rank(), dst_rank); + queue_->pop(); + break; + } + + // Remote atomic - use RDMA + int ret = backend_->rdma_atomic_fetch_add(dst_rank, result_addr, dst_ptr, operand, size); + if (ret != 0) { + LOG_ERROR("RDMA atomic add failed: dst=%d size=%lu ret=%d", dst_rank, size, ret); + } else { + LOG_DEBUG("RDMA atomic add posted successfully, polling for completion..."); + // Poll for completion + int max_attempts = get_max_poll_attempts(); + int n = 0; + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) { + LOG_DEBUG("ATOMIC_ADD completed after %d attempts, completions=%d", attempt+1, n); + break; + } + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_ERROR("Warning: ATOMIC_ADD completion not polled after %d attempts!", max_attempts); + } + } + + queue_->pop(); + break; + } + + case rdma::operation_type::ATOMIC_EXCH: { + // Atomic exchange: swap operation + // src_ptr = local result buffer, dst_ptr = remote target, atomic_operand = new value + void* result_addr = (void*)src_ptr; + uint64_t new_value = item.atomic_operand; + + LOG_DEBUG("ATOMIC_EXCH: rank=%d dst=%lx new_val=%lu result_buf=%lx size=%zu", + dst_rank, dst_ptr, new_value, src_ptr, size); + + int ret = backend_->rdma_atomic_exchange(dst_rank, result_addr, dst_ptr, new_value, size); + if (ret != 0) { + LOG_ERROR("RDMA atomic exchange failed: dst=%d size=%lu", dst_rank, size); + } else { + // Poll for completion + int n = 0; + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_DEBUG("Warning: ATOMIC_EXCH completion not polled (may be OK if async)"); + } + } + + queue_->pop(); + break; + } + + case rdma::operation_type::ATOMIC_CAS: { + // Atomic compare-and-swap + // src_ptr = local result buffer, dst_ptr = remote target, + // atomic_compare = expected value, atomic_operand = new value + void* result_addr = (void*)src_ptr; + uint64_t compare = item.atomic_compare; + uint64_t swap = item.atomic_operand; + + LOG_DEBUG("ATOMIC_CAS: rank=%d dst=%lx compare=%lu swap=%lu result_buf=%lx size=%zu", + dst_rank, dst_ptr, compare, swap, src_ptr, size); + + int ret = backend_->rdma_atomic_compare_swap(dst_rank, result_addr, dst_ptr, compare, swap, size); + if (ret != 0) { + LOG_ERROR("RDMA atomic CAS failed: dst=%d size=%lu", dst_rank, size); + } else { + // Poll for completion + int n = 0; + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_DEBUG("Warning: ATOMIC_CAS completion not polled (may be OK if async)"); + } + } + + queue_->pop(); + break; + } + default: LOG_ERROR("Unknown operation type: %d", item.header.op_type); queue_->pop(); diff --git a/iris/experimental/iris_rdma/src/logging.hpp b/iris/experimental/iris_rdma/src/logging.hpp index 80a968f9..d21eaec6 100644 --- a/iris/experimental/iris_rdma/src/logging.hpp +++ b/iris/experimental/iris_rdma/src/logging.hpp @@ -62,6 +62,29 @@ inline void log_message(log_level level, const char* level_str, const char* fmt, fflush(stderr); } +// Internal logging function with rank +inline void log_message_rank(int rank, log_level level, const char* level_str, const char* fmt, ...) { + if (level < get_log_level()) return; + + // Get timestamp + time_t now = time(nullptr); + struct tm* tm_info = localtime(&now); + char time_buf[64]; + strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", tm_info); + + // Print level, timestamp, and rank + fprintf(stderr, "[%s] [%s] [RANK %d] ", time_buf, level_str, rank); + + // Print message + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); + fflush(stderr); +} + } // namespace rdma } // namespace iris @@ -78,6 +101,19 @@ inline void log_message(log_level level, const char* level_str, const char* fmt, #define LOG_ERROR(fmt, ...) \ iris::rdma::log_message(iris::rdma::log_level::ERROR, "ERROR", fmt, ##__VA_ARGS__) +// Rank-aware logging macros +#define LOG_DEBUG_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::DEBUG, "DEBUG", fmt, ##__VA_ARGS__) + +#define LOG_INFO_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::INFO, "INFO", fmt, ##__VA_ARGS__) + +#define LOG_WARN_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::WARN, "WARN", fmt, ##__VA_ARGS__) + +#define LOG_ERROR_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::ERROR, "ERROR", fmt, ##__VA_ARGS__) + // For data debugging (separate from regular logging) #define LOG_DATA_DEBUG(fmt, ...) \ do { \ diff --git a/iris/experimental/iris_rdma/src/network_backend.hpp b/iris/experimental/iris_rdma/src/network_backend.hpp index 46d84bb4..0405665e 100644 --- a/iris/experimental/iris_rdma/src/network_backend.hpp +++ b/iris/experimental/iris_rdma/src/network_backend.hpp @@ -49,7 +49,6 @@ class network_backend { requested_dev_(device_name), context_(nullptr), pd_orig_(nullptr), - pd_parent_(nullptr), vendor_(rdma::nic_vendor::NONE), port_(1), gid_index_(0), @@ -86,11 +85,6 @@ class network_backend { heap_mr_ = nullptr; } - if (pd_parent_) { - ibv_dealloc_pd(pd_parent_); - pd_parent_ = nullptr; - } - if (pd_orig_) { ibv_dealloc_pd(pd_orig_); pd_orig_ = nullptr; @@ -332,6 +326,133 @@ class network_backend { return ret; } + /** + * @brief RDMA atomic fetch-and-add operation + * @param dst_rank Destination rank + * @param result_addr Local buffer to store the original value + * @param remote_addr Remote address to perform atomic add on + * @param add_value Value to add + * @param size Size in bytes (must be 4 or 8) + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_atomic_fetch_add(int dst_rank, void* result_addr, uint64_t remote_addr, + uint64_t add_value, size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + if (size != 4 && size != 8) { + LOG_ERROR("Atomic operations only support 4 or 8 byte sizes, got %zu", size); + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)result_addr; + sge.length = size; + sge.lkey = qp->get_lkey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.atomic.remote_addr = remote_addr; + wr.wr.atomic.rkey = qp->get_rkey(); + wr.wr.atomic.compare_add = add_value; // Value to add + + struct ibv_send_wr* bad_wr = nullptr; + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); + + LOG_DEBUG("RDMA Atomic Fetch-Add to rank %d: result=%p remote=%lx add=%lu size=%zu ret=%d", + dst_rank, result_addr, remote_addr, add_value, size, ret); + + return ret; + } + + /** + * @brief RDMA atomic exchange (swap) operation + * @param dst_rank Destination rank + * @param result_addr Local buffer to store the original value + * @param remote_addr Remote address to exchange + * @param new_value New value to write + * @param size Size in bytes (must be 4 or 8) + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_atomic_exchange(int dst_rank, void* result_addr, uint64_t remote_addr, + uint64_t new_value, size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + if (size != 4 && size != 8) { + LOG_ERROR("Atomic operations only support 4 or 8 byte sizes, got %zu", size); + return -1; + } + + // For exchange, we need a staging buffer for the new value + // ibverbs doesn't have a direct exchange, so we use CAS in a loop + // But for simplicity, we can use MLX5 extended atomics if available + // For now, we'll return an error and note this needs vendor-specific support + LOG_ERROR("RDMA atomic exchange not yet implemented - needs vendor-specific support"); + return -1; + } + + /** + * @brief RDMA atomic compare-and-swap operation + * @param dst_rank Destination rank + * @param result_addr Local buffer to store the original value + * @param remote_addr Remote address to perform CAS on + * @param compare_value Expected value + * @param swap_value Value to swap in if comparison succeeds + * @param size Size in bytes (must be 4 or 8) + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_atomic_compare_swap(int dst_rank, void* result_addr, uint64_t remote_addr, + uint64_t compare_value, uint64_t swap_value, + size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + if (size != 4 && size != 8) { + LOG_ERROR("Atomic operations only support 4 or 8 byte sizes, got %zu", size); + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)result_addr; + sge.length = size; + sge.lkey = qp->get_lkey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_ATOMIC_CMP_AND_SWP; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.atomic.remote_addr = remote_addr; + wr.wr.atomic.rkey = qp->get_rkey(); + wr.wr.atomic.compare_add = compare_value; // Expected value + wr.wr.atomic.swap = swap_value; // New value if compare succeeds + + struct ibv_send_wr* bad_wr = nullptr; + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); + + LOG_DEBUG("RDMA Atomic CAS to rank %d: result=%p remote=%lx compare=%lu swap=%lu size=%zu ret=%d", + dst_rank, result_addr, remote_addr, compare_value, swap_value, size, ret); + + return ret; + } + /** * @brief Poll completion queue for RDMA operations * @param dst_rank Destination rank (to poll specific CQ) @@ -349,20 +470,25 @@ class network_backend { int n = ibv_poll_cq(qp->get_ibv_cq(), num_to_poll, wc); if (n < 0) { - LOG_ERROR("CQ poll error for rank %d", dst_rank); + LOG_ERROR_RANK(rank_, "CQ poll error for QP to rank %d", dst_rank); return n; } // Check for errors in completions for (int i = 0; i < n; i++) { if (wc[i].status != IBV_WC_SUCCESS) { - fprintf(stderr, "[ERROR] Work completion failed: status=%d (%s) wr_id=%lu\n", - wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].wr_id); + LOG_ERROR_RANK(rank_, "Work completion failed: status=%d (%s) opcode=%d wr_id=%lu dst_rank=%d", + wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].wr_id, dst_rank); return -1; } + LOG_DEBUG_RANK(rank_, "Work completion: status=SUCCESS opcode=%d wr_id=%lu dst_rank=%d", + wc[i].opcode, wc[i].wr_id, dst_rank); } - LOG_DEBUG("Polled %d completions from rank %d", n, dst_rank); + // Dump CQ info and check its healthy + qp->dump_cq_info(); + + //LOG_DEBUG_RANK(rank_, "Polled %d completions from CQ for QP to rank %d", n, dst_rank); return n; } @@ -378,7 +504,7 @@ class network_backend { const char* requested_dev_; struct ibv_context* context_; struct ibv_pd* pd_orig_; - struct ibv_pd* pd_parent_; // For MLX5/IONIC + struct ibv_device_attr device_attr_; rdma::nic_vendor vendor_; // Port configuration @@ -491,18 +617,19 @@ class network_backend { rdma::dump_ibv_context(context_); rdma::dump_ibv_device(context_->device); + // Query device attributes (needed for atomic operations) + int err = ibv_query_device(context_, &device_attr_); + CHECK_ZERO(err, "ibv_query_device"); + LOG_DEBUG("Device attributes: max_qp_rd_atom=%d max_qp_init_rd_atom=%d", + device_attr_.max_qp_rd_atom, device_attr_.max_qp_init_rd_atom); + // Allocate protection domain pd_orig_ = ibv_alloc_pd(context_); CHECK_NNULL(pd_orig_, "ibv_alloc_pd"); rdma::dump_ibv_pd(pd_orig_); - // Create parent domain for MLX5/IONIC - if (vendor_ == rdma::nic_vendor::MLX5) { - create_parent_domain(); - } - // Query port - int err = ibv_query_port(context_, port_, &portinfo_); + err = ibv_query_port(context_, port_, &portinfo_); CHECK_ZERO(err, "ibv_query_port"); rdma::dump_ibv_port_attr(&portinfo_); @@ -515,21 +642,6 @@ class network_backend { ibv_get_device_name(context_->device)); } - void create_parent_domain() { - LOG_DEBUG("Creating parent domain..."); - - struct ibv_parent_domain_init_attr pattr; - memset(&pattr, 0, sizeof(pattr)); - - pattr.pd = pd_orig_; - pattr.td = nullptr; - pattr.comp_mask = 0; - - pd_parent_ = ibv_alloc_parent_domain(context_, &pattr); - CHECK_NNULL(pd_parent_, "ibv_alloc_parent_domain"); - rdma::dump_ibv_pd(pd_parent_); - } - void select_gid_index() { LOG_DEBUG("Selecting GID index..."); @@ -575,7 +687,7 @@ class network_backend { LOG_DEBUG("Creating queues..."); int ncqes = 64; // Number of CQ entries - int sq_length = 64; // Send queue length + int sq_length = 64; // Send queue length // TODO: FIX THAT // Resize vectors dest_info_.resize(world_size_); @@ -592,33 +704,16 @@ class network_backend { void create_cqs(int ncqes) { LOG_DEBUG("Creating completion queues: ncqes=%d", ncqes); - struct ibv_cq_init_attr_ex cq_attr; - memset(&cq_attr, 0, sizeof(cq_attr)); - - cq_attr.cqe = ncqes; - cq_attr.cq_context = nullptr; - cq_attr.channel = nullptr; - cq_attr.comp_vector = 0; - cq_attr.flags = 0; - - if (pd_parent_) { - cq_attr.comp_mask = IBV_CQ_INIT_ATTR_MASK_PD; - cq_attr.parent_domain = pd_parent_; - } - for (int i = 0; i < world_size_; i++) { - struct ibv_cq_ex* cq_ex = ibv_create_cq_ex(context_, &cq_attr); - CHECK_NNULL(cq_ex, "ibv_create_cq_ex"); - - cqs_[i] = ibv_cq_ex_to_cq(cq_ex); - CHECK_NNULL(cqs_[i], "ibv_cq_ex_to_cq"); + cqs_[i] = ibv_create_cq(context_, ncqes, nullptr, nullptr, 0); + CHECK_NNULL(cqs_[i], "ibv_create_cq"); } } void create_qps(int sq_length) { LOG_DEBUG("Creating queue pairs: sq_length=%d", sq_length); - struct ibv_qp_init_attr_ex attr; + struct ibv_qp_init_attr attr; memset(&attr, 0, sizeof(attr)); attr.cap.max_send_wr = sq_length; @@ -626,15 +721,13 @@ class network_backend { attr.cap.max_inline_data = 8; attr.sq_sig_all = 0; attr.qp_type = IBV_QPT_RC; - attr.comp_mask = IBV_QP_INIT_ATTR_PD; - attr.pd = pd_parent_ ? pd_parent_ : pd_orig_; for (int i = 0; i < world_size_; i++) { attr.send_cq = cqs_[i]; attr.recv_cq = cqs_[i]; - struct ibv_qp* qp = ibv_create_qp_ex(context_, &attr); - CHECK_NNULL(qp, "ibv_create_qp_ex"); + struct ibv_qp* qp = ibv_create_qp(pd_orig_, &attr); + CHECK_NNULL(qp, "ibv_create_qp"); qps_[i] = std::make_unique(qp, cqs_[i], i, vendor_); } @@ -687,7 +780,7 @@ class network_backend { attr.qp_state = IBV_QPS_RTR; attr.path_mtu = portinfo_.active_mtu; attr.min_rnr_timer = 12; - attr.max_dest_rd_atomic = 1; + attr.max_dest_rd_atomic = device_attr_.max_qp_rd_atom; // Use device capability attr.ah_attr.port_num = port_; if (portinfo_.link_layer == IBV_LINK_LAYER_ETHERNET) { @@ -727,7 +820,7 @@ class network_backend { attr.timeout = 14; attr.retry_cnt = 7; attr.rnr_retry = 7; - attr.max_rd_atomic = 1; + attr.max_rd_atomic = device_attr_.max_qp_init_rd_atom; // Use device capability int attr_mask = IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY; diff --git a/iris/experimental/iris_rdma/src/queue.hpp b/iris/experimental/iris_rdma/src/queue.hpp index f52339e6..0d3b30c2 100644 --- a/iris/experimental/iris_rdma/src/queue.hpp +++ b/iris/experimental/iris_rdma/src/queue.hpp @@ -19,9 +19,12 @@ namespace rdma { // Operation types - simplified for Iris enum class operation_type : uint8_t { NOP = 0, - PUT = 1, // RDMA write - GET = 2, // RDMA read - FLUSH = 3, // Flush connection + PUT = 1, // RDMA write + GET = 2, // RDMA read + FLUSH = 3, // Flush connection + ATOMIC_ADD = 4, // Atomic add + ATOMIC_EXCH = 5, // Atomic exchange + ATOMIC_CAS = 6, // Atomic compare-and-swap }; // Work item structure - metadata only, no data storage @@ -37,7 +40,11 @@ struct alignas(16) work_item_header_t { // Note: Completion is signaled by tail pointer advancement, not a flag struct alignas(16) work_item_t { - work_item_header_t header; + work_item_header_t header; // 32 bytes (0-31, padded due to alignas(16)) + // For atomic operations: operand values + uint64_t atomic_operand; // Value to add/exchange (offset 32) + uint64_t atomic_compare; // For CAS: compare value (offset 40) + // Total size: 48 bytes }; // Queue state visible to both CPU and GPU diff --git a/iris/experimental/iris_rdma/src/queue_pair.hpp b/iris/experimental/iris_rdma/src/queue_pair.hpp index e085729a..b3fea24a 100644 --- a/iris/experimental/iris_rdma/src/queue_pair.hpp +++ b/iris/experimental/iris_rdma/src/queue_pair.hpp @@ -97,6 +97,17 @@ class queue_pair { return info; } + + void dump_cq_info() const { + LOG_DEBUG("cq: %p", cq_); + LOG_DEBUG("handle: %u", cq_->channel); + LOG_DEBUG("cq_context: %p", cq_->cq_context); + LOG_DEBUG("context: %p", cq_->context); + LOG_DEBUG("cqe: %u", cq_->cqe); + LOG_DEBUG("comp_events_completed: %u", cq_->comp_events_completed); + LOG_DEBUG("async_events_completed: %u", cq_->async_events_completed); + + } private: struct ibv_qp* qp_; struct ibv_cq* cq_; diff --git a/rebuild.sh b/rebuild.sh new file mode 100755 index 00000000..9aa6c1d9 --- /dev/null +++ b/rebuild.sh @@ -0,0 +1,4 @@ +#!/bin/bash + + +pip install -e . --no-build-isolation \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100755 index 00000000..0563b8b5 --- /dev/null +++ b/run.sh @@ -0,0 +1,7 @@ +#!/bin/bash + + +export IRIS_RDMA_POLL_MAX_ATTEMPTS=1000 +export IRIS_LOG_LEVEL=DEBUG +export IRIS_DEBUG_DATA=1 +torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py \ No newline at end of file