diff --git a/docker/Dockerfile b/docker/Dockerfile index 8b49c01a..ba69c823 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,55 +1,71 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. - -FROM rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch +FROM rocm/pytorch:rocm7.0_ubuntu22.04_py3.10_pytorch_release_2.8.0 # Use bash shell for RUN commands SHELL ["/bin/bash", "-c"] # Set environment variables -ENV TRITON_PATH=/opt/triton \ - ROCM_PATH=/opt/rocm \ - OMPI_MCA_mtl="^ofi" \ - OMPI_MCA_pml="ob1" +ENV ROCM_PATH=/opt/rocm ENV LD_LIBRARY_PATH=$ROCM_PATH/lib:$LD_LIBRARY_PATH \ PATH="$ROCM_PATH/bin:$PATH" -ENV OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ - OMPI_ALLOW_RUN_AS_ROOT=1 - -# Install system packages +# Install system packages needed for Iris RDMA RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y \ - git wget ninja-build cmake python3-pip python3-dev build-essential && \ - rm -rf /var/lib/apt/lists/* + git wget cmake build-essential \ + libibverbs-dev librdmacm-dev \ + python3-pip python3-dev \ + infiniband-diags \ + perftest \ + && rm -rf /var/lib/apt/lists/* -# Install Python packages with pip +# Install Python packages RUN pip3 install --upgrade pip && \ - pip3 install wheel jupyter - -# Clone and install Triton -WORKDIR $TRITON_PATH -RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH -RUN git checkout dd5823453bcc7973eabadb65f9d827c43281c434 -RUN pip3 install -e . -ENV PYTHONPATH=$TRITON_PATH + pip3 install pybind11 -# Install rocprofiler-systems +# Set working directory WORKDIR /workspace -RUN wget https://github.com/ROCm/rocprofiler-systems/releases/download/rocm-6.3.1/rocprofiler-systems-install.py && \ - python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 6.3 && \ - rm -f rocprofiler-systems-install.py # Create entrypoint script -RUN echo '#!/bin/bash' > /entrypoint.sh && \ - echo 'echo "Welcome to the ROCm-aware Docker image!"' >> /entrypoint.sh && \ - echo 'if [ $# -eq 0 ]; then' >> /entrypoint.sh && \ - echo ' exec /bin/bash' >> /entrypoint.sh && \ - echo 'else' >> /entrypoint.sh && \ - echo ' exec "$@"' >> /entrypoint.sh && \ - echo 'fi' >> /entrypoint.sh && \ - chmod +x /entrypoint.sh +RUN printf '#!/bin/bash\n\ +echo "=== Iris RDMA Development Environment ==="\n\ +echo "ROCm version: $(cat $ROCM_PATH/.info/version 2>/dev/null || echo unknown)"\n\ +echo "PyTorch version: $(python -c '\''import torch; print(torch.__version__)'\'' 2>/dev/null)"\n\ +\n\ +# GPU detection using PyTorch\n\ +python -c '\''\n\ +import torch\n\ +if torch.cuda.is_available():\n\ + count = torch.cuda.device_count()\n\ + print(f"GPUs available: {count}")\n\ + for i in range(count):\n\ + name = torch.cuda.get_device_name(i)\n\ + print(f" GPU[{i}]: {name}")\n\ +else:\n\ + print("GPUs available: 0")\n\ +'\'' 2>/dev/null || echo "GPUs available: 0"\n\ +\n\ +# InfiniBand detection\n\ +if [ -d /dev/infiniband ]; then\n\ + IB_COUNT=$(ls /dev/infiniband/uverbs* 2>/dev/null | wc -l)\n\ + echo "InfiniBand devices available: $IB_COUNT"\n\ + if [ $IB_COUNT -gt 0 ]; then\n\ + echo "InfiniBand device(s): $(ls /sys/class/infiniband/ 2>/dev/null | tr '\''\n'\'' '\'' '\'')"\n\ + fi\n\ +else\n\ + echo "InfiniBand devices available: 0"\n\ +fi\n\ +echo "======================================"\n\ +if [ $# -eq 0 ]; then\n\ + exec /bin/bash\n\ +else\n\ + exec "$@"\n\ +fi\n' > /entrypoint.sh + +RUN chmod +x /entrypoint.sh # Set the entrypoint -ENTRYPOINT ["/bin/bash", "-c", "source /entrypoint.sh && 
exec bash"] \ No newline at end of file +ENTRYPOINT ["/entrypoint.sh"] +CMD ["/bin/bash"] + diff --git a/docker/build.sh b/docker/build.sh index 973c9366..d86bf5a7 100755 --- a/docker/build.sh +++ b/docker/build.sh @@ -1,13 +1,13 @@ #!/bin/bash -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Build miniQP Docker image SCRIPT_DIR=$(dirname "$(realpath "$0")") - -IMAGE_NAME=${1:-"iris-dev"} +IMAGE_NAME=${1:-"iris-rdma"} pushd "$SCRIPT_DIR" > /dev/null -docker build -t $IMAGE_NAME . +echo "Building Docker image: $IMAGE_NAME" +docker build -t $IMAGE_NAME --network=host . popd > /dev/null + diff --git a/docker/run.sh b/docker/run.sh index c967875b..0864bab4 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -1,14 +1,42 @@ #!/bin/bash -# SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Run Iris RDMA Docker container with InfiniBand support +IMAGE_NAME=${1:-"iris-rdma"} +WORKSPACE_DIR=$(cd "$(dirname "$0")/.." && pwd) -IMAGE_NAME=${1:-"iris-dev"} -WORKSPACE_DIR=${2:-"$(pwd)"} +echo "Starting miniQP container..." +echo " Image: $IMAGE_NAME" +echo " Workspace: $WORKSPACE_DIR" + +# Auto-detect InfiniBand devices +IB_DEVICES="" +if [ -d /dev/infiniband ]; then + for dev in /dev/infiniband/uverbs*; do + if [ -e "$dev" ]; then + IB_DEVICES="$IB_DEVICES --device=$dev" + fi + done + if [ -n "$IB_DEVICES" ]; then + echo " InfiniBand devices: $(ls /dev/infiniband/uverbs* 2>/dev/null | wc -l) found" + fi +else + echo " Warning: No InfiniBand devices found" +fi +echo "" + +docker run -it --rm \ + --network=host \ + --device=/dev/kfd \ + --device=/dev/dri \ + $IB_DEVICES \ + --group-add video \ + --cap-add=SYS_PTRACE \ + --cap-add=IPC_LOCK \ + --security-opt seccomp=unconfined \ + -v "$WORKSPACE_DIR:$WORKSPACE_DIR" \ + -w "$WORKSPACE_DIR" \ + --shm-size=16G \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + $IMAGE_NAME -docker run -it --network=host --device=/dev/kfd\ - --device=/dev/dri --group-add video\ - --cap-add=SYS_PTRACE --security-opt seccomp=unconfined\ - -v "$WORKSPACE_DIR:$WORKSPACE_DIR" -w "$WORKSPACE_DIR"\ - --shm-size=16G --ulimit memlock=-1\ - --ulimit stack=67108864 $IMAGE_NAME diff --git a/examples/22_rdma_producer_consumer/README.md b/examples/22_rdma_producer_consumer/README.md new file mode 100644 index 00000000..1d9139d4 --- /dev/null +++ b/examples/22_rdma_producer_consumer/README.md @@ -0,0 +1,69 @@ +# 22. RDMA Producer-Consumer + +Producer-consumer pattern using InfiniBand RDMA for multi-node communication. 
+ +## Overview + +This example demonstrates: +- Producer Triton kernel generates data on Rank 0 +- RDMA transfer from Rank 0 to Rank 1 +- Consumer Triton kernel verifies data on Rank 1 + +## Requirements + +- InfiniBand network adapter +- libibverbs-dev installed +- Iris built with RDMA support + +## Architecture + +``` +Rank 0 (Producer) Rank 1 (Consumer) +───────────────── ───────────────── +producer_kernel() + ↓ writes +GPU → CPU buffer + ↓ +RDMA PUT ──────────────────→ CPU buffer + ↓ + CPU → GPU + ↓ + consumer_kernel() + ↓ verifies + ✓ Success +``` + +## Usage + +### Single Node (2 GPUs) +```bash +torchrun --nproc_per_node=2 rdma_producer_consumer.py +``` + +### Multi-Node (2 Nodes, 1 GPU each) +```bash +# Node 0 +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=0 \ + --master_addr= --master_port=29500 \ + rdma_producer_consumer.py + +# Node 1 +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=1 \ + --master_addr= --master_port=29500 \ + rdma_producer_consumer.py +``` + +## Expected Output + +``` +[Rank 0/2] Initialized on cuda:0 +[Rank 1/2] Initialized on cuda:1 +[Rank 0] Producing data +[Rank 0] First 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 0] RDMA transfer to Rank 1 +[Rank 0] RDMA completed +[Rank 1] Consuming data +[Rank 1] Received first 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 1] Verified: 4096/4096 +[Rank 1] SUCCESS! +``` diff --git a/examples/22_rdma_producer_consumer/rdma_producer_consumer.py b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py new file mode 100755 index 00000000..d6e12363 --- /dev/null +++ b/examples/22_rdma_producer_consumer/rdma_producer_consumer.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +import os +import sys +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import time + +import iris.experimental.iris_rdma as iris_rdma + + +@triton.jit +def producer_put_kernel( + buffer_ptr, + n_elements, + dst_rank: tl.constexpr, + device_ctx, + BLOCK_SIZE: tl.constexpr, +): + """ + Producer kernel that enqueues RDMA put operations. + Data must already be in buffer_ptr (filled by fill_data_kernel). + Uses symmetric heap model: same buffer offset in local and remote heap. + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Src and dst are the same pointer + ptrs = buffer_ptr + offsets + + # Enqueue RDMA operation + iris_rdma.put(ptrs, ptrs, dst_rank, device_ctx, mask) + + +@triton.jit +def consumer_kernel( + input_ptr, + result_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """ + Consumer kernel that verifies received data. + Expected pattern: ascending numbers 0, 1, 2, ..., n_elements-1 + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load received data + data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + + # Check if it matches expected pattern (0, 1, 2, 3, ...) 
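    # (casting the element index to the data dtype gives the expected value)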
+ expected = offsets.to(data.dtype) + is_correct = (data == expected).to(tl.int32) + + tl.store(result_ptr + offsets, is_correct, mask=mask) + + +def main(): + + dtype = torch.bfloat16 + + # Initialize distributed + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + device_id = torch.device(f"cuda:{local_rank}") + + dist.init_process_group( + backend='nccl', + device_id=device_id + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + + if world_size < 2: + print("This example requires at least 2 ranks") + sys.exit(1) + + torch.cuda.set_device(local_rank) + device = f'cuda:{local_rank}' + + print(f"[Rank {rank}/{world_size}] Initialized on {device}") + + # Create Iris RDMA context with queue + heap_size = 1024 * 1024 * 8 # 8MB + queue_size = 512 + ctx = iris_rdma.iris(heap_size=heap_size, queue_size=queue_size) + + print(f"[Rank {rank}] Iris RDMA initialized") + print(f"[Rank {rank}] - Heap base: {ctx.get_heap_base():#x}") + print(f"[Rank {rank}] - Queue ptr: {ctx.get_queue_ptr():#x}") + + # Get device context for Triton kernels + device_ctx = ctx.get_device_context() + + # Allocate buffers in symmetric heap + n_elements = 4091 + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + # Allocate on the symmetric heap + local_buffer = ctx.zeros(n_elements, dtype=dtype) + + ctx.barrier() + + # ============================================================ + # PRODUCER (Rank 0): Generate data and RDMA put to Rank 1 + # ============================================================ + if rank == 0: + print(f"\n[Rank 0] === Producer: Generating and Sending Data ===") + dst_rank = 1 + + # Step 1: Fill buffer with data using PyTorch (no race condition) + print(f"[Rank 0] Filling buffer with data using PyTorch...") + local_buffer.copy_(torch.arange(n_elements, dtype=dtype, device=device)) + print(f"[Rank 0] Data filled, first 10: {local_buffer[:10].tolist()}") + + # Step 2: Launch RDMA enqueue kernel (data already in memory) + print(f"[Rank 0] Launching RDMA enqueue kernel...") + producer_put_kernel[grid]( + local_buffer, + n_elements, + dst_rank=dst_rank, + device_ctx=device_ctx, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Wait for GPU to finish enqueueing + torch.cuda.synchronize() + print(f"[Rank 0] RDMA operations enqueued to queue") + + ctx.barrier() + print(f"[Rank {rank}] Barrier complete, all RDMA operations finished") + + # ============================================================ + # CONSUMER (Rank 1): Verify received data + # ============================================================ + if rank == 1: + print(f"\n[Rank 1] === Consumer: Verifying Received Data ===") + + # Show received data + print(f"[Rank 1] Received data first 10: {local_buffer[:10].tolist()}") + + # Verify data (use int32 for result buffer - stores 0 or 1 for correctness) + result_buffer = torch.zeros(n_elements, dtype=torch.int32, device=device) + + consumer_kernel[grid]( + local_buffer, + result_buffer, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + result_cpu = result_buffer.cpu() + num_correct = result_cpu.sum().item() + num_total = n_elements + + print(f"[Rank 1] Verified: {int(num_correct)}/{num_total}") + + if num_correct == num_total: + print(f"\n" + "="*60) + print(f"[Rank 1] SUCCESS! 
Data matches perfectly!") + else: + print(f"[Rank 1] FAILED - Data mismatch!") + first_wrong_idx = (result_cpu == 0).nonzero(as_tuple=True)[0] + if len(first_wrong_idx) > 0: + idx = first_wrong_idx[0].item() + print(f"[Rank 1] First wrong at index {idx}") + print(f"[Rank 1] Expected: {idx}") + print(f"[Rank 1] Got: {local_buffer[idx].item()}") + sys.exit(1) + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + diff --git a/examples/23_rdma_consumer_pull/README.md b/examples/23_rdma_consumer_pull/README.md new file mode 100644 index 00000000..288b5014 --- /dev/null +++ b/examples/23_rdma_consumer_pull/README.md @@ -0,0 +1,88 @@ +# 23. RDMA Consumer Pull (GET) + +Consumer-pull pattern using InfiniBand RDMA GET operations for multi-node communication. + +## Overview + +This example demonstrates: +- Rank 1 (Server) prepares data in its heap +- Rank 0 (Client) uses RDMA GET to pull data from Rank 1 +- Triton kernel verifies pulled data on Rank 0 + +**Key Difference from Example 22:** +- **Example 22 (PUT)**: Sender initiates - Rank 0 pushes data to Rank 1 +- **Example 23 (GET)**: Receiver initiates - Rank 0 pulls data from Rank 1 + +## Requirements + +- InfiniBand network adapter +- libibverbs-dev installed +- Iris built with RDMA support + +## Architecture + +``` +Rank 1 (Server) Rank 0 (Client) +─────────────── ─────────────── +Data in heap + ↓ +CPU buffer + RDMA GET ←──────────┐ + │ +CPU buffer ←───────────────────────────────────────┘ + ↓ +CPU → GPU + ↓ +verify_kernel() + ↓ verifies +✓ Success +``` + +## Usage + +### Single Node (2 GPUs) +```bash +torchrun --nproc_per_node=2 rdma_consumer_pull.py +``` + +### Multi-Node (2 Nodes, 1 GPU each) +```bash +# Node 0 (Client - pulls data) +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=0 \ + --master_addr= --master_port=29500 \ + rdma_consumer_pull.py + +# Node 1 (Server - provides data) +torchrun --nnodes=2 --nproc_per_node=1 --node_rank=1 \ + --master_addr= --master_port=29500 \ + rdma_consumer_pull.py +``` + +## Expected Output + +``` +[Rank 0/2] Initialized on cuda:0 +[Rank 1/2] Initialized on cuda:1 +[Rank 1] Server: Preparing Data +[Rank 1] Data ready, first 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 0] Client: Pulling Data via RDMA GET +[Rank 0] RDMA GET operations enqueued to queue +[Rank 0] Barrier complete, all RDMA operations finished +[Rank 0] Received data first 10: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] +[Rank 0] Verified: 4091/4091 +============================================================ +[Rank 0] SUCCESS! Data pulled correctly via RDMA GET! +``` + +## RDMA GET vs PUT + +### When to use GET: +- **Consumer-initiated**: Receiver decides when to pull data +- **Pull-based flow control**: Consumer controls rate +- **Useful for**: Demand-driven workloads, load balancing + +### When to use PUT: +- **Producer-initiated**: Sender decides when to push data +- **Push-based flow control**: Producer controls rate +- **Useful for**: Pipeline parallelism, streaming workloads + diff --git a/examples/23_rdma_consumer_pull/rdma_consumer_pull.py b/examples/23_rdma_consumer_pull/rdma_consumer_pull.py new file mode 100644 index 00000000..9c1cf3a1 --- /dev/null +++ b/examples/23_rdma_consumer_pull/rdma_consumer_pull.py @@ -0,0 +1,198 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+ +import os +import sys +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import time + +import iris.experimental.iris_rdma as iris_rdma + + +@triton.jit +def consumer_get_kernel( + local_ptr, + remote_ptr, + n_elements, + src_rank: tl.constexpr, + device_ctx, + BLOCK_SIZE: tl.constexpr, +): + """ + Consumer kernel that enqueues RDMA get operations to pull data. + Uses symmetric heap model: remote_ptr points to same offset in remote heap. + After RDMA get completes, data will be available at local_ptr. + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Local and remote pointers (same offset in symmetric heap) + local_ptrs = local_ptr + offsets + remote_ptrs = remote_ptr + offsets + + # Enqueue RDMA GET operation: pull from remote to local + iris_rdma.get(local_ptrs, remote_ptrs, src_rank, device_ctx, mask) + + +@triton.jit +def verify_kernel( + input_ptr, + result_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """ + Verification kernel that checks received data. + Expected pattern: ascending numbers 0, 1, 2, ..., n_elements-1 + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load received data + data = tl.load(input_ptr + offsets, mask=mask, other=0.0) + + # Check if it matches expected pattern (0, 1, 2, 3, ...) + expected = offsets.to(data.dtype) + is_correct = (data == expected).to(tl.int32) + + tl.store(result_ptr + offsets, is_correct, mask=mask) + + +def main(): + + dtype = torch.bfloat16 + + # Initialize distributed + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + device_id = torch.device(f"cuda:{local_rank}") + + dist.init_process_group( + backend='nccl', + device_id=device_id + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + + if world_size < 2: + print("This example requires at least 2 ranks") + sys.exit(1) + + torch.cuda.set_device(local_rank) + device = f'cuda:{local_rank}' + + print(f"[Rank {rank}/{world_size}] Initialized on {device}") + + # Create Iris RDMA context with queue + heap_size = 1024 * 1024 * 8 # 8MB + queue_size = 512 + ctx = iris_rdma.iris(heap_size=heap_size, queue_size=queue_size) + + print(f"[Rank {rank}] Iris RDMA initialized") + print(f"[Rank {rank}] - Heap base: {ctx.get_heap_base():#x}") + print(f"[Rank {rank}] - Queue ptr: {ctx.get_queue_ptr():#x}") + + # Get device context for Triton kernels + device_ctx = ctx.get_device_context() + + # Allocate buffers in symmetric heap + n_elements = 4091 + BLOCK_SIZE = 256 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + # Allocate on the symmetric heap + local_buffer = ctx.zeros(n_elements, dtype=dtype) + + ctx.barrier() + + # ============================================================ + # SERVER (Rank 1): Prepare data for RDMA get + # ============================================================ + if rank == 1: + print(f"\n[Rank 1] === Server: Preparing Data ===") + + # Fill buffer with data using PyTorch + print(f"[Rank 1] Filling buffer with data...") + local_buffer.copy_(torch.arange(n_elements, dtype=dtype, device=device)) + torch.cuda.synchronize() + print(f"[Rank 1] Data ready, first 10: {local_buffer[:10].tolist()}") + print(f"[Rank 1] Waiting for client to pull data...") + + # ============================================================ + # CLIENT (Rank 0): Pull data using RDMA get + # 
============================================================ + if rank == 0: + print(f"\n[Rank 0] === Client: Pulling Data via RDMA GET ===") + src_rank = 1 + + # Launch RDMA GET enqueue kernel + print(f"[Rank 0] Launching RDMA GET kernel to pull from Rank {src_rank}...") + consumer_get_kernel[grid]( + local_buffer, # local destination + local_buffer, # remote source (same offset in symmetric heap) + n_elements, + src_rank=src_rank, + device_ctx=device_ctx, + BLOCK_SIZE=BLOCK_SIZE, + ) + + # Wait for GPU to finish enqueueing + torch.cuda.synchronize() + print(f"[Rank 0] RDMA GET operations enqueued to queue") + + ctx.barrier() + print(f"[Rank {rank}] Barrier complete, all RDMA operations finished") + + # ============================================================ + # CLIENT (Rank 0): Verify pulled data + # ============================================================ + if rank == 0: + print(f"\n[Rank 0] === Verifying Pulled Data ===") + + # Show received data + print(f"[Rank 0] Received data first 10: {local_buffer[:10].tolist()}") + + # Verify data (use int32 for result buffer - stores 0 or 1 for correctness) + result_buffer = torch.zeros(n_elements, dtype=torch.int32, device=device) + + verify_kernel[grid]( + local_buffer, + result_buffer, + n_elements, + BLOCK_SIZE=BLOCK_SIZE, + ) + + result_cpu = result_buffer.cpu() + num_correct = result_cpu.sum().item() + num_total = n_elements + + print(f"[Rank 0] Verified: {int(num_correct)}/{num_total}") + + if num_correct == num_total: + print(f"\n" + "="*60) + print(f"[Rank 0] SUCCESS! Data pulled correctly via RDMA GET!") + else: + print(f"[Rank 0] FAILED - Data mismatch!") + first_wrong_idx = (result_cpu == 0).nonzero(as_tuple=True)[0] + if len(first_wrong_idx) > 0: + idx = first_wrong_idx[0].item() + print(f"[Rank 0] First wrong at index {idx}") + print(f"[Rank 0] Expected: {idx}") + print(f"[Rank 0] Got: {local_buffer[idx].item()}") + sys.exit(1) + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + diff --git a/examples/24_rdma_atomic_add/README.md b/examples/24_rdma_atomic_add/README.md new file mode 100644 index 00000000..75d0cbf0 --- /dev/null +++ b/examples/24_rdma_atomic_add/README.md @@ -0,0 +1,127 @@ +# RDMA Atomic Add Example + +This example demonstrates RDMA atomic fetch-and-add operations using Iris RDMA. + +## Overview + +In this example: +- **Rank 0** maintains a shared counter in its symmetric heap +- **All ranks** (0 through N-1) atomically increment rank 0's counter +- Each rank adds its own rank number + 1 (i.e., rank 0 adds 1, rank 1 adds 2, etc.) +- The atomic operation returns the old value before incrementing +- Rank 0 verifies the final sum + +## Key Concepts + +### Atomic Fetch-and-Add +```python +iris_rdma.atomic_add( + result_ptr, # Local buffer to store old value + counter_ptr, # Remote counter location (symmetric heap) + increment, # Value to add + dst_rank, # Which rank owns the counter + device_ctx, # Device context + mask, # Triton mask +) +``` + +- **Atomic**: Operation is indivisible - no race conditions +- **Fetch**: Returns the original value before the add +- **Symmetric Heap**: All ranks use same offset, automatically translated + +### Expected Result + +For N ranks, each rank i adds (i+1): +``` +Final counter = 1 + 2 + 3 + ... + N = N × (N+1) / 2 +``` + +For 2 ranks: 1 + 2 = 3 +For 4 ranks: 1 + 2 + 3 + 4 = 10 +For 8 ranks: 1 + 2 + 3 + ... 
+ 8 = 36 + +## Running the Example + +### With 2 ranks: +```bash +torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py +``` + +### With 4 ranks: +```bash +torchrun --nproc_per_node=4 examples/24_rdma_atomic_add/rdma_atomic_add.py +``` + +### With debug logging: +```bash +IRIS_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py +``` + +## Expected Output + +``` +[Rank 0/2] Initialized on cuda:0 +[Rank 1/2] Initialized on cuda:1 +[Rank 0] Iris RDMA initialized +[Rank 1] Iris RDMA initialized + +[Rank 0] === Testing Atomic Add === +[Rank 0] Initial counter value: 0 +[Rank 0] Waiting for other ranks to increment... + +[Rank 0] Atomically adding 1 to rank 0's counter... +[Rank 1] Atomically adding 2 to rank 0's counter... +[Rank 0] Atomic add completed. Old value was: 0 +[Rank 1] Atomic add completed. Old value was: 1 + +[Rank 0] === Verification === +[Rank 0] Final counter value: 3 +[Rank 0] Expected value: 3 +[Rank 0] Each rank added: [1, 2] + +============================================================ +[Rank 0] SUCCESS! Atomic operations worked correctly! +============================================================ +``` + +## How It Works + +1. **Initialization**: All ranks initialize Iris RDMA with symmetric heaps +2. **Buffer Allocation**: Each rank allocates counter/result buffers at same offset +3. **Atomic Operations**: + - Ranks launch Triton kernels that call `iris_rdma.atomic_add()` + - Triton kernel enqueues atomic operation to device queue + - CPU proxy thread dequeues and executes RDMA atomic via InfiniBand + - Original value is returned to result buffer +4. **Verification**: Rank 0 checks that sum equals expected value + +## Key Features Demonstrated + +- ✅ **RDMA Atomics**: Hardware-level atomic operations over InfiniBand +- ✅ **Symmetric Heap**: Automatic address translation between ranks +- ✅ **Fetch-and-Add**: Returns old value atomically +- ✅ **GPU-initiated**: Triton kernel directly initiates RDMA operations +- ✅ **Zero-copy**: No intermediate buffers or CPU involvement for data path + +## Notes + +- Atomic operations require **64-bit integers** (`torch.int64` or `torch.uint64`) +- 32-bit atomics are also supported by changing the size parameter +- Operations are **synchronous** - kernel waits for completion before returning +- All ranks must allocate buffers at the **same symmetric heap offset** + +## Troubleshooting + +**Counter value is wrong:** +- Check that all ranks successfully performed atomic operations +- Verify InfiniBand connection is working +- Enable debug logging to see RDMA operations + +**Atomics not supported error:** +- Ensure your InfiniBand HCA supports atomic operations +- Most modern Mellanox/NVIDIA and Broadcom NICs support this + +**Hang on barrier:** +- Check that all ranks are running +- Verify NCCL is properly configured + diff --git a/examples/24_rdma_atomic_add/rdma_atomic_add.py b/examples/24_rdma_atomic_add/rdma_atomic_add.py new file mode 100644 index 00000000..824a678e --- /dev/null +++ b/examples/24_rdma_atomic_add/rdma_atomic_add.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +""" +RDMA Atomic Add Example + +Demonstrates RDMA atomic fetch-and-add operations between ranks. +Each rank atomically increments a counter on rank 0. 
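
Note: in this version of the script only rank 0 issues the atomic, adding 42 to the
counter that lives on rank 1; the README describes the more general all-ranks pattern.

Run with:
    torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py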
+""" + +import os +import sys +import torch +import torch.distributed as dist +import triton +import triton.language as tl + +from iris.experimental import iris_rdma + + +@triton.jit +def atomic_add_kernel( + counter_ptr, + result_ptr, + increment, + dst_rank: tl.constexpr, + device_ctx, +): + """ + Each thread atomically adds its increment to the remote counter. + Returns the old value before increment. + """ + pid = tl.program_id(0) + + # Only first thread does the atomic add + if pid == 0: + # Create a mask for single element operation + mask = tl.full([1], 1, dtype=tl.int1) + + # Atomic add: increment counter on dst_rank, get old value + iris_rdma.atomic_add( + result_ptr, # Where to store old value + counter_ptr, # Remote counter location (symmetric heap) + increment, # Value to add + dst_rank, # Which rank has the counter + device_ctx, + mask, + ) + + +def main(): + dtype = torch.int64 # Atomics require int64/uint64 + + # Initialize distributed + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + device_id = torch.device(f"cuda:{local_rank}") + + dist.init_process_group( + backend='nccl', + device_id=device_id + ) + + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f"cuda:{rank}") + + print(f"[Rank {rank}/{world_size}] Initialized on {device}") + + # Initialize RDMA context + ctx = iris_rdma.IrisRDMA() + print(f"[Rank {rank}] Iris RDMA initialized") + print(f"[Rank {rank}] - Heap base: {ctx.get_heap_base():#x}") + print(f"[Rank {rank}] - Queue ptr: {ctx.get_queue_ptr():#x}") + + # Get device context for Triton kernels + device_ctx = ctx.get_device_context() + + # Allocate counter and result buffer in symmetric heap + counter = ctx.zeros(1, dtype=dtype) # Shared counter + result = ctx.zeros(1, dtype=dtype) # Store old value + + ctx.barrier() + + # ============================================================ + # Rank 0 atomically increments rank 1's counter + # ============================================================ + print(f"\n[Rank {rank}] === Testing Atomic Add ===") + + if rank == 1: + print(f"[Rank 1] Initial counter value: {counter[0].item()}") + print(f"[Rank 1] Waiting for rank 0 to increment...") + + ctx.barrier() + + # Only rank 0 performs the atomic operation (to avoid local atomic on rank 1) + if rank == 0: + increment = 42 # Arbitrary test value + target_rank = 1 + print(f"[Rank 0] Atomically adding {increment} to rank {target_rank}'s counter...") + + # Launch atomic add kernel + grid = (1,) # Single thread + atomic_add_kernel[grid]( + counter, # Counter location (same offset on all ranks) + result, # Where to store old value + increment, # Value to add + dst_rank=target_rank, + device_ctx=device_ctx, + ) + + # Synchronize GPU + torch.cuda.synchronize() + + # Read the old value returned by atomic add + old_value = result.cpu()[0].item() + print(f"[Rank 0] Atomic add completed. Old value was: {old_value}") + + ctx.barrier() + + # ============================================================ + # Rank 1: Verify final counter value + # ============================================================ + if rank == 1: + print(f"\n[Rank 1] === Verification ===") + final_value = counter.cpu()[0].item() + expected = 42 # Only rank 0 added 42 + + print(f"[Rank 1] Final counter value: {final_value}") + print(f"[Rank 1] Expected value: {expected}") + + if final_value == expected: + print("\n" + "="*60) + print("[Rank 1] SUCCESS! 
RDMA atomic add worked correctly!") + print("="*60) + else: + print(f"[Rank 1] FAILED - Counter value mismatch!") + print(f"[Rank 1] Expected: {expected}") + print(f"[Rank 1] Got: {final_value}") + sys.exit(1) + + ctx.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + diff --git a/iris/experimental/__init__.py b/iris/experimental/__init__.py index dbab5167..44369673 100644 --- a/iris/experimental/__init__.py +++ b/iris/experimental/__init__.py @@ -9,8 +9,9 @@ Current experimental features: - iris_gluon: Gluon-based implementation using @aggregate and @gluon.jit +- iris_rdma: InfiniBand RDMA backend for multi-node communication -Usage: +Usage (Gluon): >>> import iris.experimental.iris_gluon as iris_gl >>> from triton.experimental import gluon >>> from triton.experimental.gluon import language as gl @@ -24,8 +25,39 @@ >>> def kernel(IrisDeviceCtx: gl.constexpr, context_tensor): >>> ctx = IrisDeviceCtx.initialize(context_tensor) >>> ctx.load(buffer, 1) + +Usage (RDMA): + >>> import iris.experimental.iris_rdma as iris_rdma + >>> import torch.distributed as dist + >>> + >>> # Initialize PyTorch Distributed first + >>> dist.init_process_group(backend='nccl') + >>> + >>> # Host side + >>> ctx = iris_rdma.iris(heap_size=2**30) + >>> device_ctx = ctx.get_device_context() + >>> + >>> # Device side + >>> @triton.jit + >>> def kernel(dst_ptr, data, device_ctx, dst_rank): + >>> iris_rdma.put(dst_ptr, data, dst_rank, device_ctx, mask) """ from . import iris_gluon +# Try to import iris_rdma (optional, requires InfiniBand) +try: + from . import iris_rdma + _has_rdma = True +except ImportError as e: + _has_rdma = False + import warnings + warnings.warn( + f"iris_rdma not available: {e}\n" + "InfiniBand RDMA support requires libibverbs-dev and building with CMake.", + ImportWarning + ) + __all__ = ["iris_gluon"] +if _has_rdma: + __all__.append("iris_rdma") diff --git a/iris/experimental/iris_rdma.py b/iris/experimental/iris_rdma.py new file mode 100644 index 00000000..45b2adb9 --- /dev/null +++ b/iris/experimental/iris_rdma.py @@ -0,0 +1,822 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Iris RDMA: Experimental InfiniBand RDMA Backend for Multi-Node Communication + +This module provides InfiniBand RDMA support for multi-node communication in Iris. +Unlike the main Iris which uses HIP IPC for intra-node GPU communication, this backend +enables inter-node communication via RDMA over InfiniBand. 
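GPU kernels enqueue work items into a shared queue; a CPU proxy thread dequeues them
and issues the corresponding ibverbs operations.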
+ +Key Features: +- InfiniBand Queue Pair (QP) setup and management +- Symmetric heap with RDMA memory registration +- RDMA put/get operations in Triton kernels +- PyTorch Distributed integration for bootstrapping + +Example: + >>> import iris.experimental.iris_rdma as iris_rdma + >>> import torch.distributed as dist + >>> + >>> # Initialize PyTorch Distributed + >>> dist.init_process_group(backend='nccl') + >>> + >>> # Create RDMA context + >>> ctx = iris_rdma.iris(heap_size=2**30) # 1GB heap + >>> device_ctx = ctx.get_device_context() # For passing to Triton kernels + >>> + >>> @triton.jit + >>> def kernel(dst_ptr, data, device_ctx, dst_rank, BLOCK_SIZE: tl.constexpr): + >>> pid = tl.program_id(0) + >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + >>> + >>> # RDMA put to remote rank + >>> iris_rdma.put(dst_ptr + offsets, data, dst_rank, device_ctx) +""" + +import torch +import torch.distributed as dist +import triton +import triton.language as tl +import numpy as np +import sys +import os + +# Import the C++ backend module +try: + from . import _iris_rdma_backend as backend +except ImportError: + raise ImportError( + "Iris RDMA backend not available. " + "Make sure the module is built with InfiniBand support. " + "Set IRIS_RDMA_DEBUG=1 for more information." + ) + +# Import logging +from ..logging import logger + + +class IrisRDMA: + """ + Main Iris RDMA class for multi-node RDMA operations. + + This class provides a unified interface for RDMA-based communication + across multiple nodes using InfiniBand. + + Args: + heap_size (int): Size of the symmetric heap in bytes. Default: 1GB (2^30) + process_group: PyTorch distributed process group (default: WORLD) + device_name (str): InfiniBand device name (default: auto-detect) + + Example: + >>> ctx = iris_rdma.iris(heap_size=2**31) # 2GB heap + >>> print(f"Rank {ctx.rank} of {ctx.world_size}") + >>> buffer = ctx.zeros(1024, dtype=torch.float32) + """ + + def __init__(self, heap_size=1 << 30, process_group=None, queue_size=512): + # Check if distributed is initialized + if not dist.is_initialized(): + raise RuntimeError( + "PyTorch distributed must be initialized. " + "Call torch.distributed.init_process_group() first." 
+ ) + + if process_group is None: + process_group = dist.group.WORLD + + # Get rank and world size + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.device_id = self.rank % torch.cuda.device_count() + self.device = f"cuda:{self.device_id}" + + torch.cuda.set_device(self.device_id) + + # Create torch_bootstrap + self._bootstrap = backend.torch_bootstrap(process_group) + + # Allocate symmetric heap (CPU pinned memory for now) + # TODO: Support GPU memory with GPUDirect RDMA + self.heap_size = heap_size + self.heap_offset = 0 + self.alignment = 1024 + + # Create GPU memory pool + self.memory_pool = torch.empty(heap_size, device=self.device, dtype=torch.int8) + + self._manager = backend.rdma_proxy(self._bootstrap, self.memory_pool, queue_size) + self._manager.start_proxy_thread() + + self._backend = self._manager + + logger.info(f"[Rank {self.rank}] Using rdma_proxy with queue (size={queue_size})") + + self.remote_heap_bases = [] + for i in range(self.world_size): + self.remote_heap_bases.append(self._manager.get_remote_heap_base(i)) + + logger.info(f"[Rank {self.rank}] Iris RDMA initialized: heap_size={heap_size}, " + f"heap_base={self._manager.get_heap_base():#x}") + + def __del__(self): + """Clean up resources""" + if hasattr(self, '_manager') and self._manager is not None: + self._manager.stop_proxy_thread() + + def get_heap_base(self): + """Get local heap base address""" + return self._manager.get_heap_base() + + def get_queue_ptr(self): + """Get queue pointer for Triton kernels""" + return self._manager.get_queue_ptr() + + def get_device_context(self): + """ + Get device context tensor for passing to Triton kernels. + + The context tensor encodes: + - [0]: current rank + - [1]: world size + - [2]: queue pointer (for enqueueing RDMA operations) + - [3:]: heap base addresses for all ranks + + Returns: + torch.Tensor: Device context tensor (on GPU) + + Example: + >>> ctx = iris_rdma.iris() + >>> device_ctx = ctx.get_device_context() + >>> # Pass device_ctx to Triton kernel + """ + # Create context tensor: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] + context_size = 3 + self.world_size + context = torch.zeros(context_size, dtype=torch.int64, device=self.device) + + context[0] = self.rank + context[1] = self.world_size + context[2] = self.get_queue_ptr() + + for i in range(self.world_size): + context[3 + i] = self.remote_heap_bases[i] + + return context + + def zeros(self, *size, dtype=torch.float32, device=None): + """ + Allocate and initialize a tensor with zeros in the symmetric heap. 
+ + Args: + *size: Tensor dimensions + dtype: Data type (default: torch.float32) + device: Device placement (default: GPU for direct kernel access) + + Returns: + torch.Tensor: Allocated tensor (on GPU by default) + + Example: + >>> buffer = ctx.zeros(1024, 1024, dtype=torch.float32) + """ + if device is None: + device = self.device # Use GPU by default (for GPUDirect) + + # Calculate size in bytes + elem_size = torch.tensor([], dtype=dtype).element_size() + numel = int(np.prod(size)) + size_bytes = numel * elem_size + + # Align allocation + aligned_offset = (self.heap_offset + self.alignment - 1) // self.alignment * self.alignment + + if aligned_offset + size_bytes > self.heap_size: + raise RuntimeError(f"Heap exhausted: requested {size_bytes} bytes, " + f"available {self.heap_size - aligned_offset}") + + # Create tensor view into memory pool + byte_offset = aligned_offset + byte_end = byte_offset + size_bytes + + # Get the memory slice and view as the requested dtype + memory_slice = self.memory_pool[byte_offset:byte_end] + tensor = memory_slice.view(dtype).reshape(size) + + # Zero initialize + tensor.zero_() + + # Update offset + self.heap_offset = byte_end + + logger.debug(f"[Rank {self.rank}] Allocated tensor: size={size}, " + f"offset={byte_offset:#x}, ptr={tensor.data_ptr():#x}") + + return tensor + + def barrier(self): + """ + Synchronize all ranks and drain RDMA queue. + + Waits for: + 1. All enqueued RDMA operations to complete (queue drains) + 2. All ranks to reach this barrier + + Example: + >>> ctx.barrier() # Wait for all ranks and RDMA completion + """ + + # First, synchronize with all GPUs + torch.cuda.synchronize() + + # Then, synchronize with all ranks + dist.barrier() + + # Finally, wait for queue to drain (all work processed) + self.wait_queue_drain() + + def wait_queue_drain(self, timeout=30.0): + """ + Wait for the CPU proxy thread to process all enqueued work items. + + Spins until queue is empty (head == tail), meaning all work has been + processed and popped by the CPU proxy thread. + + Args: + timeout: Maximum time to wait in seconds + + Raises: + TimeoutError: If queue doesn't drain within timeout + """ + import time + start = time.time() + + while time.time() - start < timeout: + # Check if queue is empty (head == tail) + if self._manager.is_queue_empty(): + return + + # Small sleep to avoid burning CPU + time.sleep(0.0001) # 100 microseconds + + raise TimeoutError(f"Queue did not drain within {timeout}s") + + def rdma_put(self, dst_rank, local_addr, remote_addr, size): + """ + Perform RDMA write (put) to remote rank. + + Args: + dst_rank: Destination rank + local_addr: Local buffer address (int or tensor.data_ptr()) + remote_addr: Remote buffer address (int) + size: Size in bytes + + Returns: + int: 0 on success, non-zero on error + + Example: + >>> src = ctx.zeros(1024, dtype=torch.float32) + >>> dst_addr = ctx.remote_heap_bases[1] # Remote rank 1's heap + >>> ctx.rdma_put(1, src.data_ptr(), dst_addr, src.numel() * 4) + """ + if isinstance(local_addr, torch.Tensor): + local_addr = local_addr.data_ptr() + + return self._backend.rdma_write(dst_rank, local_addr, remote_addr, size) + + def rdma_get(self, dst_rank, local_addr, remote_addr, size): + """ + Perform RDMA read (get) from remote rank. 
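        The read is posted to the queue pair for dst_rank; pair it with
        poll_completion(dst_rank) to wait for the completion (see poll_completion).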
+ + Args: + dst_rank: Source rank (destination of the QP) + local_addr: Local buffer address (int or tensor.data_ptr()) + remote_addr: Remote buffer address (int) + size: Size in bytes + + Returns: + int: 0 on success, non-zero on error + + Example: + >>> dst = ctx.zeros(1024, dtype=torch.float32) + >>> src_addr = ctx.remote_heap_bases[1] # Remote rank 1's heap + >>> ctx.rdma_get(1, dst.data_ptr(), src_addr, dst.numel() * 4) + """ + if isinstance(local_addr, torch.Tensor): + local_addr = local_addr.data_ptr() + + return self._backend.rdma_read(dst_rank, local_addr, remote_addr, size) + + def poll_completion(self, dst_rank, max_completions=1): + """ + Poll completion queue for RDMA operations. + + Args: + dst_rank: Destination rank (to poll specific CQ) + max_completions: Maximum number of completions to poll + + Returns: + int: Number of completions polled (negative on error) + + Example: + >>> ctx.rdma_put(1, src.data_ptr(), remote_addr, size) + >>> while ctx.poll_completion(1) == 0: + >>> pass # Wait for completion + """ + return self._backend.poll_cq(dst_rank, max_completions) + + def __repr__(self): + return f"" + + +def iris(heap_size=1 << 30, process_group=None, queue_size=512): + """ + Factory function to create Iris RDMA context. + + Args: + heap_size (int): Size of the symmetric heap in bytes + process_group: PyTorch distributed process group + queue_size (int): Queue size for GPU->CPU RDMA operations + + Returns: + IrisRDMA: RDMA context object + + Example: + >>> import iris.experimental.iris_rdma as iris_rdma + >>> ctx = iris_rdma.iris(heap_size=2**30) + """ + return IrisRDMA(heap_size, process_group, queue_size) + + +############################################################################# +# Triton Device-Side APIs +############################################################################# + +@triton.jit +def _translate(ptr, from_rank, to_rank, heap_bases): + """ + Translate a pointer from one rank's address space to another. + + This implements the symmetric heap model where each rank has a heap at + a different base address, but offsets are preserved across ranks. + + Args: + ptr: Pointer in from_rank's address space + from_rank: Source rank ID + to_rank: Target rank ID + heap_bases: Pointer to array of heap base addresses + + Returns: + Translated pointer in to_rank's address space + """ + from_base = tl.load(heap_bases + from_rank) + to_base = tl.load(heap_bases + to_rank) + + # Convert to int to compute difference + ptr_int = ptr.to(tl.uint64) + + # Find the offset from from_rank heap + offset = ptr_int - from_base + + # Byte cast for byte offset addition + to_base_byte = to_base.to(tl.pointer_type(tl.int8)) + + # Find the offset into the to_rank heap + translated_ptr_byte = to_base_byte + offset + + # Cast back to original pointer type + translated_ptr = translated_ptr_byte.to(ptr.dtype) + + return translated_ptr + + +@triton.jit +def _wait_for_completion(queue_ptr, queue_pos): + """ + Wait for CPU to process a queue item. + + Spins until tail pointer advances past our queue position, + indicating the CPU has processed and popped our item. 
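    The tail value is re-read with a system-scope atomic add of zero (acquire) so the
    GPU observes updates from the CPU proxy thread instead of a stale cached value.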
+ + Args: + queue_ptr: Queue context pointer + queue_pos: Queue position to wait for (returned from _enqueue_rdma_op) + """ + state_ptr = queue_ptr.to(tl.pointer_type(tl.uint64)) + + # Load tail pointer (offset 2 in QueueState) + # Use volatile and cache modifier to prevent caching + tail_ptr = tl.load(state_ptr + 2, cache_modifier=".cv", volatile=True) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + # Spin until CPU advances tail past our position + while queue_pos >= current_tail: + tail_ptr = tl.load(state_ptr + 2, cache_modifier=".cv", volatile=True) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + +@triton.jit +def _enqueue_rdma_op(dst_ptr, src_ptr, to_rank: tl.constexpr, op_code: tl.constexpr, queue_ptr, mask): + """ + Internal: Enqueue an RDMA operation to the queue. + + Args: + dst_ptr: Destination pointer on remote rank + src_ptr: Source pointer (local address where data is stored in registered heap) + to_rank: Target rank ID + op_code: Operation type (1=PUT, 2=GET) + queue_ptr: Queue pointer from device context + mask: Triton mask for valid elements + """ + # Queue structure (from queue.hpp): + # struct QueueState { + # WorkItem* items; // offset 0 + # uint64_t* head; // offset 8 + # uint64_t* tail; // offset 16 + # uint64_t* tailCache; // offset 24 + # int32_t size; // offset 32 + # }; + + state_ptr = queue_ptr.to(tl.pointer_type(tl.uint64)) + + # Load QueueState fields + items_ptr = tl.load(state_ptr + 0) + head_ptr = tl.load(state_ptr + 1) + tail_ptr = tl.load(state_ptr + 2) + + # Load size (at offset 32 bytes = 4 * uint64) + size_ptr = queue_ptr.to(tl.pointer_type(tl.int32)) + size = tl.load(size_ptr + 8) + + # Atomic increment head to reserve slot + head_ptr_typed = head_ptr.to(tl.pointer_type(tl.uint64)) + prev_head = tl.atomic_add(head_ptr_typed, 1, sem='relaxed', scope='sys') + + # Wait for slot to be free: spin if prev_head >= size + *tail + size_u64 = size.to(tl.uint64) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + while prev_head >= size_u64 + current_tail: + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + # Calculate slot position + slot_idx = prev_head % size_u64 + + # WorkItem structure (32 bytes): + # struct WorkItem { + # uint64_t dst_ptr; // offset 0 + # uint64_t src_ptr; // offset 8 + # uint32_t size_bytes; // offset 16 - WRITE LAST as ready flag + # uint16_t rank; // offset 20 + # uint8_t op_type; // offset 22 + # uint8_t reserved; // offset 23 + # }; + WORK_ITEM_SIZE_BYTES = 32 + + slot_offset_bytes = slot_idx * WORK_ITEM_SIZE_BYTES + + # Get pointer to this work item + items_ptr_u64 = items_ptr.to(tl.pointer_type(tl.uint64)) + slot_ptr_u64 = items_ptr_u64 + (slot_offset_bytes // 8).to(tl.int32) + + # Extract destination address (min of pointer block) + dst_ptr_u64 = dst_ptr.to(tl.uint64) + dst_ptr_val = tl.min(dst_ptr_u64, axis=0) + + # Extract source address (min of pointer block where data is stored) + src_ptr_u64 = src_ptr.to(tl.uint64) + src_ptr_val = tl.min(src_ptr_u64, axis=0) + max_src_ptr = tl.max(src_ptr_u64, axis=0) + + # Infer element size from pointer type + # src_ptr is a block of pointers with a specific element type (e.g., pointer) + # The pointer dtype tells us the element type, which has a known size + # Map Triton dtypes to their byte 
sizes + ptr_dtype = src_ptr.dtype.element_ty # Get the element type that the pointer points to + + # Get element size in bytes from the dtype + # tl.float16 -> 2, tl.float32 -> 4, tl.float64 -> 8, etc. + if ptr_dtype == tl.float16 or ptr_dtype == tl.bfloat16: + element_size_bytes = 2 + elif ptr_dtype == tl.float32 or ptr_dtype == tl.int32 or ptr_dtype == tl.uint32: + element_size_bytes = 4 + elif ptr_dtype == tl.float64 or ptr_dtype == tl.int64 or ptr_dtype == tl.uint64: + element_size_bytes = 8 + elif ptr_dtype == tl.int8 or ptr_dtype == tl.uint8: + element_size_bytes = 1 + elif ptr_dtype == tl.int16 or ptr_dtype == tl.uint16: + element_size_bytes = 2 + else: + # Default to 4 bytes for unknown types + element_size_bytes = 4 + + # Calculate total size in bytes + # Count number of valid elements based on mask + mask_int = mask.to(tl.int32) + num_elements = tl.sum(mask_int, axis=0) + size_bytes = (num_elements * element_size_bytes).to(tl.uint32) + + # Write header fields (but NOT size_bytes yet - it's the ready flag) + # Write dst_ptr (offset 0) + tl.store(slot_ptr_u64 + 0, dst_ptr_val) + + # Write src_ptr (offset 8) + tl.store(slot_ptr_u64 + 1, src_ptr_val) + + # Write rank + op_type (offset 20-23) + metadata = (to_rank & 0xFFFF) | ((op_code & 0xFF) << 16) + slot_ptr_u32 = slot_ptr_u64.to(tl.pointer_type(tl.uint32)) + tl.store(slot_ptr_u32 + 5, metadata.to(tl.uint32)) + + # Write size_bytes LAST as ready flag (offset 16) + size_bytes_ptr = (slot_ptr_u32 + 4).to(tl.pointer_type(tl.uint32)) + tl.atomic_xchg(size_bytes_ptr, size_bytes, sem='release', scope='sys') + + # Return queue position for waiting + return prev_head + + +@triton.jit +def put(dst_ptr, src_ptr, dst_rank: tl.constexpr, device_ctx, mask): + """ + RDMA put (write) operation from Triton kernel. + + Uses symmetric heap model: dst_ptr is in current rank's address space, + and will be automatically translated to remote rank's address space. + + IMPORTANT: Data must be stored at src_ptr BEFORE calling this function. + This avoids race conditions between GPU writes and CPU RDMA reads. + The CPU proxy thread will dequeue and perform the actual RDMA write. + + Args: + dst_ptr: Destination pointer in CURRENT rank's address space (symmetric heap) + src_ptr: Source pointer (local address in registered heap) where data is already stored - can be block of pointers + dst_rank: Target rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Example: + >>> @triton.jit + >>> def kernel(local_buffer, device_ctx, dst_rank, BLOCK_SIZE: tl.constexpr): + >>> pid = tl.program_id(0) + >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> + >>> data = generate_data(offsets) + >>> src_ptrs = local_buffer + offsets + >>> dst_ptrs = local_buffer + offsets # Same offset, symmetric heap! + >>> + >>> # Store data FIRST to avoid race condition + >>> tl.store(src_ptrs, data, mask=mask) + >>> + >>> # Then enqueue RDMA operation + >>> iris_rdma.put(dst_ptrs, src_ptrs, dst_rank, device_ctx, mask) + """ + # Extract context fields + # Context format: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] 
+ my_rank = tl.load(device_ctx + 0) + queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate dst_ptr from current rank's address space to remote rank's + translated_dst_ptr = _translate(dst_ptr, my_rank, dst_rank, heap_bases) + + # Enqueue PUT operation (op_code=1) with translated address + _enqueue_rdma_op(translated_dst_ptr, src_ptr, dst_rank, 1, queue_ptr, mask) + + +@triton.jit +def get(dst_ptr, src_ptr, from_rank: tl.constexpr, device_ctx, mask): + """ + RDMA get (read) operation from Triton kernel. + + Uses symmetric heap model: src_ptr is in current rank's address space, + and will be automatically translated to remote rank's address space. + + Enqueues a request to read data from remote rank via RDMA and WAITS for completion. + The CPU proxy thread will dequeue, perform the RDMA read, then pop the item. + This function spins until the tail pointer advances, then data is ready at dst_ptr. + + Args: + dst_ptr: Local destination pointer where data will be written - can be block of pointers + src_ptr: Source pointer in CURRENT rank's address space (symmetric heap) + from_rank: Source rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Example: + >>> @triton.jit + >>> def kernel(local_buffer, device_ctx, from_rank, BLOCK_SIZE: tl.constexpr): + >>> pid = tl.program_id(0) + >>> offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> + >>> src_ptrs = local_buffer + offsets # Same offset in symmetric heap! + >>> dst_ptrs = local_buffer + offsets + >>> # RDMA read from remote rank - blocks until complete + >>> iris_rdma.get(dst_ptrs, src_ptrs, from_rank, device_ctx, mask) + >>> + >>> # Data is now ready at dst_ptrs, can use it immediately + >>> data = tl.load(dst_ptrs, mask=mask) + """ + # Extract context fields + # Context format: [rank, world_size, queue_ptr, heap_base_0, heap_base_1, ...] + my_rank = tl.load(device_ctx + 0) + queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate src_ptr from current rank's address space to remote rank's + translated_src_ptr = _translate(src_ptr, my_rank, from_rank, heap_bases) + + # Enqueue GET operation (op_code=2) + # For GET: translated_src_ptr is remote source, dst_ptr is local destination + queue_pos = _enqueue_rdma_op(translated_src_ptr, dst_ptr, from_rank, 2, queue_ptr, mask) + + # Wait for CPU to complete the RDMA read + _wait_for_completion(queue_ptr, queue_pos) + + # Data is now ready at dst_ptr (CPU has written it there via RDMA) + + +@triton.jit +def atomic_add(result_ptr, dst_ptr, add_value, dst_rank: tl.constexpr, device_ctx, mask): + """ + RDMA atomic fetch-and-add operation from Triton kernel. + + Atomically adds a value to remote memory and returns the original value. + Uses symmetric heap model: dst_ptr is in current rank's address space, + and will be automatically translated to remote rank's address space. + + Args: + result_ptr: Local pointer where the original value will be stored + dst_ptr: Destination pointer in CURRENT rank's address space (symmetric heap) + add_value: Value to add (must be scalar uint64 or int64) + dst_rank: Destination rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Note: Only supports 8-byte (uint64/int64) atomic operations. + The result_ptr will contain the original value before the add. 
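    Example:
        >>> @triton.jit
        >>> def kernel(counter_ptr, result_ptr, device_ctx, dst_rank: tl.constexpr):
        >>>     mask = tl.full([1], 1, dtype=tl.int1)
        >>>     # Fetch-and-add 1 on the remote counter; the old value lands in result_ptr
        >>>     iris_rdma.atomic_add(result_ptr, counter_ptr, 1, dst_rank, device_ctx, mask)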
+ """ + # Extract context fields + my_rank = tl.load(device_ctx + 0) + queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate dst_ptr from current rank's address space to remote rank's + translated_dst_ptr = _translate(dst_ptr, my_rank, dst_rank, heap_bases) + + # Enqueue ATOMIC_ADD operation (op_code=4) + # For ATOMIC_ADD: result_ptr is local result buffer, translated_dst_ptr is remote target + queue_pos = _enqueue_atomic_op(result_ptr, translated_dst_ptr, dst_rank, 4, + add_value, 0, queue_ptr, mask) + + # Wait for CPU to complete the atomic operation + _wait_for_completion(queue_ptr, queue_pos) + + # Result is now ready at result_ptr (original value before add) + + +@triton.jit +def atomic_cas(result_ptr, dst_ptr, compare_value, swap_value, dst_rank: tl.constexpr, device_ctx, mask): + """ + RDMA atomic compare-and-swap operation from Triton kernel. + + Atomically compares remote memory with expected value and swaps if equal. + Returns the original value. Uses symmetric heap model. + + Args: + result_ptr: Local pointer where the original value will be stored + dst_ptr: Destination pointer in CURRENT rank's address space (symmetric heap) + compare_value: Expected value (must be scalar uint64 or int64) + swap_value: New value if comparison succeeds (must be scalar uint64 or int64) + dst_rank: Destination rank ID (must be compile-time constant) + device_ctx: Device context from iris_rdma.get_device_context() + mask: Triton mask for valid elements + + Note: Only supports 8-byte (uint64/int64) atomic operations. + The result_ptr will contain the original value at the remote location. + If result == compare_value, the swap succeeded. + """ + # Extract context fields + my_rank = tl.load(device_ctx + 0) + queue_ptr = tl.load(device_ctx + 2) + heap_bases = device_ctx + 3 + + # Translate dst_ptr from current rank's address space to remote rank's + translated_dst_ptr = _translate(dst_ptr, my_rank, dst_rank, heap_bases) + + # Enqueue ATOMIC_CAS operation (op_code=6) + queue_pos = _enqueue_atomic_op(result_ptr, translated_dst_ptr, dst_rank, 6, + swap_value, compare_value, queue_ptr, mask) + + # Wait for CPU to complete the atomic operation + _wait_for_completion(queue_ptr, queue_pos) + + # Result is now ready at result_ptr (original value from remote) + + +@triton.jit +def _enqueue_atomic_op(result_ptr, dst_ptr, to_rank: tl.constexpr, op_code: tl.constexpr, + operand, compare, queue_ptr, mask): + """ + Internal: Enqueue an atomic RDMA operation to the queue. 
+ + Args: + result_ptr: Local pointer for result + dst_ptr: Destination pointer on remote rank (already translated) + to_rank: Target rank ID + op_code: Operation type (4=ATOMIC_ADD, 6=ATOMIC_CAS) + operand: Operand value (add_value or swap_value) + compare: Compare value (0 for ADD, compare_value for CAS) + queue_ptr: Queue pointer from device context + mask: Triton mask for valid elements + """ + state_ptr = queue_ptr.to(tl.pointer_type(tl.uint64)) + + # Load QueueState fields + items_ptr = tl.load(state_ptr + 0) + head_ptr = tl.load(state_ptr + 1) + tail_ptr = tl.load(state_ptr + 2) + + # Load size (at offset 32 bytes = 4 * uint64) + size_ptr = queue_ptr.to(tl.pointer_type(tl.int32)) + size = tl.load(size_ptr + 8) + + # Atomic increment head to reserve slot + head_ptr_typed = head_ptr.to(tl.pointer_type(tl.uint64)) + prev_head = tl.atomic_add(head_ptr_typed, 1, sem='relaxed', scope='sys') + + # Wait for slot to be free + size_u64 = size.to(tl.uint64) + tail_ptr_typed = tail_ptr.to(tl.pointer_type(tl.uint64)) + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + while prev_head >= size_u64 + current_tail: + current_tail = tl.atomic_add(tail_ptr_typed, 0, sem='acquire', scope='sys') + + # Calculate slot position + slot_idx = prev_head % size_u64 + + # WorkItem structure (48 bytes total): + # Header (32 bytes due to alignas(16)): + # offset 0: uint64_t dst_ptr + # offset 8: uint64_t src_ptr (result_ptr for atomics) + # offset 16: uint32_t size_bytes (WRITE LAST as ready flag) + # offset 20: uint16_t rank + # offset 22: uint8_t op_type + # offset 23: uint8_t reserved + # offset 24-31: padding (alignas(16) pads header to 32 bytes) + # Atomic fields (16 bytes): + # offset 32: uint64_t atomic_operand + # offset 40: uint64_t atomic_compare + WORK_ITEM_SIZE_BYTES = 48 # Header (32 with padding) + atomic fields (16) + + slot_offset_bytes = slot_idx * WORK_ITEM_SIZE_BYTES + + # Get pointer to this work item + items_ptr_u64 = items_ptr.to(tl.pointer_type(tl.uint64)) + slot_ptr_u64 = items_ptr_u64 + (slot_offset_bytes // 8).to(tl.int32) + + # Cast pointers to uint64 + dst_ptr_val = tl.cast(dst_ptr, tl.uint64) + result_ptr_val = tl.cast(result_ptr, tl.uint64) + operand_u64 = tl.cast(operand, tl.uint64) + compare_u64 = tl.cast(compare, tl.uint64) + + # Write WorkItem fields (except size which is written last as ready flag) + # Offset 0: dst_ptr (remote target) + tl.store(slot_ptr_u64 + 0, dst_ptr_val) + + # Offset 8: src_ptr (result buffer) + tl.store(slot_ptr_u64 + 1, result_ptr_val) + + # Offset 32: atomic_operand (offset 32 bytes = 4 * 8 bytes) + tl.store(slot_ptr_u64 + 4, operand_u64) + + # Offset 40: atomic_compare (offset 40 bytes = 5 * 8 bytes) + tl.store(slot_ptr_u64 + 5, compare_u64) + + # Offset 20 (bytes): Pack rank (16 bits) + op_type (8 bits) into 32 bits + # Same as regular RDMA operations + slot_ptr_u32 = slot_ptr_u64.to(tl.pointer_type(tl.uint32)) + metadata = (to_rank & 0xFFFF) | ((op_code & 0xFF) << 16) + tl.store(slot_ptr_u32 + 5, metadata) # offset 20 bytes = 5 * 4 bytes + + # Offset 16 (bytes) / 4 (uint32): size_bytes - WRITE LAST as ready flag + # For atomics, size is always 8 bytes + tl.store(slot_ptr_u32 + 4, tl.cast(8, tl.uint32)) # offset 16 bytes = 4 * 4 bytes + + return prev_head + + +__all__ = [ + "IrisRDMA", + "iris", + "put", + "get", + "atomic_add", + "atomic_cas", +] + diff --git a/iris/experimental/iris_rdma/python/bindings.cpp b/iris/experimental/iris_rdma/python/bindings.cpp new file mode 100644 index 00000000..1af97958 --- /dev/null +++ 
b/iris/experimental/iris_rdma/python/bindings.cpp
@@ -0,0 +1,211 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+
+/******************************************************************************
+ * Python Bindings for Iris RDMA Backend using PyBind11
+ *****************************************************************************/
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <torch/extension.h>
+#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+
+#include "network_backend.hpp"
+#include "queue_pair.hpp"
+#include "torch_bootstrap.hpp"
+#include "iris_manager.hpp"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_iris_rdma_backend, m) {
+    m.doc() =
+        "Iris RDMA Backend: InfiniBand RDMA with PyTorch Integration";
+
+    // Expose NICVendor enum
+    py::enum_<iris::rdma::nic_vendor>(m, "nic_vendor")
+        .value("NONE", iris::rdma::nic_vendor::NONE)
+        .value("IONIC", iris::rdma::nic_vendor::IONIC)
+        .value("BNXT", iris::rdma::nic_vendor::BNXT)
+        .value("MLX5", iris::rdma::nic_vendor::MLX5)
+        .export_values();
+
+    // Expose qp_info_t struct
+    py::class_<iris::rdma::qp_info_t>(m, "qp_info_t")
+        .def(py::init<>())
+        .def_readwrite("qp_num", &iris::rdma::qp_info_t::qp_num)
+        .def_readwrite("lkey", &iris::rdma::qp_info_t::lkey)
+        .def_readwrite("rkey", &iris::rdma::qp_info_t::rkey)
+        .def_readwrite("dst_rank", &iris::rdma::qp_info_t::dst_rank)
+        .def("__repr__", [](const iris::rdma::qp_info_t& info) {
+            return "<qp_info_t qp_num=" + std::to_string(info.qp_num) +
+                   " dst_rank=" + std::to_string(info.dst_rank) + ">";
+        });
+
+    // Expose torch_bootstrap
+    py::class_<iris::rdma::torch_bootstrap,
+               std::shared_ptr<iris::rdma::torch_bootstrap>>(m, "torch_bootstrap")
+        .def(py::init([](py::object pg_obj) {
+                 // Extract c10d::ProcessGroup from Python object
+                 auto pg_ptr =
+                     pg_obj.cast<c10::intrusive_ptr<c10d::ProcessGroup>>();
+                 return std::make_shared<iris::rdma::torch_bootstrap>(pg_ptr);
+             }),
+             py::arg("process_group"))
+        .def("get_rank", &iris::rdma::torch_bootstrap::get_rank)
+        .def("get_world_size", &iris::rdma::torch_bootstrap::get_world_size)
+        .def("barrier", &iris::rdma::torch_bootstrap::barrier);
+
+    // Expose queue_pair (read-only access)
+    py::class_<iris::queue_pair>(m, "queue_pair")
+        .def("get_qp_num", &iris::queue_pair::get_qp_num)
+        .def("get_lkey", &iris::queue_pair::get_lkey)
+        .def("get_rkey", &iris::queue_pair::get_rkey)
+        .def("get_dst_rank", &iris::queue_pair::get_dst_rank)
+        .def("get_info", &iris::queue_pair::get_info)
+        .def("__repr__", [](const iris::queue_pair& qp) {
+            return "<queue_pair qp_num=" + std::to_string(qp.get_qp_num()) +
+                   " dst_rank=" + std::to_string(qp.get_dst_rank()) + ">";
+        });
+
+    // Expose network_backend
+    py::class_<iris::network_backend>(m, "network_backend")
+        .def(py::init<std::shared_ptr<iris::rdma::torch_bootstrap>, const char*>(),
+             py::arg("bootstrap"), py::arg("device_name") = nullptr,
+             "Create network_backend with PyTorch bootstrap")
+        .def("init", &iris::network_backend::init,
+             "Initialize the network (setup QPs, transition to RTS)")
+        .def(
+            "register_memory",
+            [](iris::network_backend& self, py::object obj, size_t size = 0) {
+                void* ptr = nullptr;
+                size_t actual_size = size;
+
+                // Check if it's an integer (raw pointer)
+                if (PyLong_Check(obj.ptr())) {
+                    ptr = reinterpret_cast<void*>(PyLong_AsVoidPtr(obj.ptr()));
+                    if (size == 0) {
+                        throw std::runtime_error("Size must be specified for raw pointer");
+                    }
+                    actual_size = size;
+                }
+                // Check if it's a PyTorch tensor
+                else if (THPVariable_Check(obj.ptr())) {
+                    auto t = THPVariable_Unpack(obj.ptr());
+                    ptr = t.data_ptr();
+                    actual_size = t.numel() * t.element_size();
+                }
+                else {
+                    throw std::runtime_error("Expected a PyTorch tensor or integer address");
+                }
+
+                self.register_memory(ptr, actual_size);
+            },
+            py::arg("obj"), py::arg("size") = 0,
+            "Register memory for RDMA (supports CPU pinned or GPU memory via GPUDirect)")
+        .def("get_qp", &iris::network_backend::get_qp, py::arg("dst_rank"),
+             py::return_value_policy::reference_internal,
+             "Get queue pair for destination rank")
.def("get_qp_info", &iris::network_backend::get_qp_info, py::arg("dst_rank"), + "Get QP info for destination rank") + .def("get_rank", &iris::network_backend::get_rank, "Get rank") + .def("get_world_size", &iris::network_backend::get_world_size, "Get world size") + .def("get_remote_heap_base", &iris::network_backend::get_remote_heap_base, + py::arg("rank"), + "Get remote heap base address for a rank") + .def("get_heap_base", &iris::network_backend::get_heap_base, + "Get local heap base address") + .def("get_heap_size", &iris::network_backend::get_heap_size, + "Get heap size in bytes") + .def("rdma_write", + [](iris::network_backend& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + return self.rdma_write(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA write to remote rank (local_addr is integer address)") + .def("rdma_read", + [](iris::network_backend& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + return self.rdma_read(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA read from remote rank (local_addr is integer address)") + .def("poll_cq", &iris::network_backend::poll_cq, + py::arg("dst_rank"), py::arg("max_completions") = 1, + "Poll completion queue for RDMA operations") + .def("__repr__", [](const iris::network_backend& backend) { + return ""; + }); + + py::class_(m, "rdma_proxy") + .def(py::init([](std::shared_ptr bootstrap, py::object heap_tensor, int queue_size) { + // Extract heap pointer from tensor + if (!THPVariable_Check(heap_tensor.ptr())) { + throw std::runtime_error("heap_tensor must be a PyTorch tensor"); + } + auto heap = THPVariable_Unpack(heap_tensor.ptr()); + void* heap_ptr = heap.data_ptr(); + size_t heap_size = heap.numel() * heap.element_size(); + + return new iris::rdma_proxy(bootstrap, heap_ptr, heap_size, queue_size); + }), + py::arg("bootstrap"), py::arg("heap_tensor"), py::arg("queue_size") = 512, + "Create rdma_proxy with network_backend + Queue + Proxy Thread") + .def("start_proxy_thread", &iris::rdma_proxy::start_proxy_thread, + "Start proxy thread that processes RDMA operations from queue") + .def("stop_proxy_thread", &iris::rdma_proxy::stop_proxy_thread, + "Stop proxy thread") + .def("get_queue_ptr", + [](iris::rdma_proxy& self) { + return reinterpret_cast(self.get_queue_ptr()); + }, + "Get queue pointer for Triton kernels") + .def("get_heap_base", &iris::rdma_proxy::get_heap_base, + "Get local heap base address") + .def("get_remote_heap_base", &iris::rdma_proxy::get_remote_heap_base, + py::arg("rank"), + "Get remote heap base address for a rank") + .def("get_rank", &iris::rdma_proxy::get_rank, "Get rank") + .def("get_world_size", &iris::rdma_proxy::get_world_size, "Get world size") + .def("is_queue_empty", &iris::rdma_proxy::is_queue_empty, + "Check if queue is empty (all work items processed)") + .def("rdma_write", + [](iris::rdma_proxy& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + auto backend = self.get_backend(); + return backend->rdma_write(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 
0, + "RDMA write to remote rank (local_addr is integer address)") + .def("rdma_read", + [](iris::rdma_proxy& self, int dst_rank, uint64_t local_addr, + uint64_t remote_addr, size_t size, uint64_t wr_id) { + auto backend = self.get_backend(); + return backend->rdma_read(dst_rank, reinterpret_cast(local_addr), + remote_addr, size, wr_id); + }, + py::arg("dst_rank"), py::arg("local_addr"), py::arg("remote_addr"), + py::arg("size"), py::arg("wr_id") = 0, + "RDMA read from remote rank (local_addr is integer address)") + .def("poll_cq", + [](iris::rdma_proxy& self, int dst_rank, int max_completions) { + auto backend = self.get_backend(); + return backend->poll_cq(dst_rank, max_completions); + }, + py::arg("dst_rank"), py::arg("max_completions") = 1, + "Poll completion queue for RDMA operations") + .def("__repr__", [](const iris::rdma_proxy& mgr) { + return ""; + }); +} + diff --git a/iris/experimental/iris_rdma/src/ibv_utils.hpp b/iris/experimental/iris_rdma/src/ibv_utils.hpp new file mode 100644 index 00000000..7a9632f1 --- /dev/null +++ b/iris/experimental/iris_rdma/src/ibv_utils.hpp @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include +#include +#include + +#include "logging.hpp" + +namespace iris { +namespace rdma { + +// Error checking macros +#define CHECK_ZERO(expr, msg) \ + do { \ + int ret = (expr); \ + if (ret != 0) { \ + LOG_ERROR("%s failed with code %d: %s", msg, ret, strerror(ret)); \ + abort(); \ + } \ + } while (0) + +#define CHECK_NNULL(ptr, msg) \ + do { \ + if ((ptr) == nullptr) { \ + LOG_ERROR("%s returned NULL", msg); \ + abort(); \ + } \ + } while (0) + +// Vendor detection +enum class nic_vendor { NONE, IONIC, BNXT, MLX5 }; + +// QP destination info for connection +struct qp_dest_info_t { + int lid; + int qpn; + int psn; + union ibv_gid gid; +}; + +// QP metadata exposed to Python +struct qp_info_t { + uint32_t qp_num; + uint32_t lkey; + uint32_t rkey; + int dst_rank; +}; + +// Helper functions +inline void dump_ibv_device(struct ibv_device* device) { + LOG_DEBUG("IBV Device: %s", ibv_get_device_name(device)); +} + +inline void dump_ibv_context(struct ibv_context* ctx) { + LOG_DEBUG("IBV Context: device=%s", ctx->device->name); +} + +inline void dump_ibv_pd(struct ibv_pd* pd) { + LOG_DEBUG("IBV PD: handle=%u", pd->handle); +} + +inline void dump_ibv_port_attr(struct ibv_port_attr* attr) { + LOG_DEBUG("Port Attr: state=%d, lid=%d, link_layer=%d, active_mtu=%d", + attr->state, attr->lid, attr->link_layer, attr->active_mtu); +} + +inline int ibv_mtu_to_int(enum ibv_mtu mtu) { + switch (mtu) { + case IBV_MTU_256: + return 256; + case IBV_MTU_512: + return 512; + case IBV_MTU_1024: + return 1024; + case IBV_MTU_2048: + return 2048; + case IBV_MTU_4096: + return 4096; + default: + fprintf(stderr, "[ERROR] Invalid ibv_mtu\n"); + return 0; + } +} + +} // namespace rdma +} // namespace iris + diff --git a/iris/experimental/iris_rdma/src/iris_manager.hpp b/iris/experimental/iris_rdma/src/iris_manager.hpp new file mode 100644 index 00000000..ac6c0d8d --- /dev/null +++ b/iris/experimental/iris_rdma/src/iris_manager.hpp @@ -0,0 +1,471 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
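+
+// Illustrative usage sketch (not part of this header): how the rdma_proxy
+// class below is expected to be driven from Python through the pybind11
+// bindings, assuming torch.distributed is already initialized and the
+// extension imports as iris.experimental._iris_rdma_backend. The heap size
+// and the group object passed to torch_bootstrap are placeholders.
+//
+//   import torch, torch.distributed as dist
+//   from iris.experimental import _iris_rdma_backend as rdma
+//
+//   dist.init_process_group(backend="nccl")
+//   bootstrap = rdma.torch_bootstrap(dist.group.WORLD)
+//   heap = torch.zeros(1 << 24, dtype=torch.uint8, device="cuda")
+//   proxy = rdma.rdma_proxy(bootstrap, heap, queue_size=512)
+//   proxy.start_proxy_thread()
+//   # ... launch Triton kernels that enqueue work via proxy.get_queue_ptr() ...
+//   while not proxy.is_queue_empty():
+//       pass
+//   proxy.stop_proxy_thread()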
+
+/**
+ * @file iris_manager.hpp
+ * @brief Complete Iris RDMA integration: Network + Queue + Proxy Thread
+ *
+ * Combines:
+ * - NetworkBackend (InfiniBand RDMA)
+ * - TritonDeviceQueue (GPU->CPU queue)
+ * - Proxy Thread (processes RDMA operations from queue)
+ */
+
+#pragma once
+
+#include <algorithm>
+#include <atomic>
+#include <chrono>
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+#include <thread>
+
+#include "network_backend.hpp"
+#include "queue.hpp"
+
+namespace iris {
+
+/**
+ * @brief Get maximum number of polling attempts from environment variable
+ * @return Max attempts (default 100, configurable via IRIS_RDMA_POLL_MAX_ATTEMPTS)
+ */
+inline int get_max_poll_attempts() {
+  static int max_attempts = []() {
+    const char* env = std::getenv("IRIS_RDMA_POLL_MAX_ATTEMPTS");
+    if (env) {
+      int val = std::atoi(env);
+      if (val > 0) return val;
+    }
+    return 100;  // Default
+  }();
+  return max_attempts;
+}
+
+/**
+ * @brief Complete Iris RDMA Proxy
+ *
+ * Integration of network_backend + TritonDeviceQueue + Proxy Thread
+ * Provides a unified interface for Triton kernels to perform RDMA operations
+ */
+class rdma_proxy {
+ public:
+  /**
+   * @brief Constructor
+   * @param bootstrap PyTorch bootstrap for distributed communication
+   * @param heap_base Pointer to symmetric heap
+   * @param heap_size Size of symmetric heap in bytes
+   * @param queue_size Queue capacity (default: 512)
+   */
+  rdma_proxy(std::shared_ptr<rdma::torch_bootstrap> bootstrap,
+             void* heap_base,
+             size_t heap_size,
+             int queue_size = 512)
+      : heap_base_((uint64_t)heap_base),
+        heap_size_(heap_size),
+        running_(false) {
+
+    // Step 1: Create network_backend and initialize
+    backend_ = std::make_unique<network_backend>(bootstrap);
+    backend_->init();
+
+    // Step 2: Register symmetric heap (collective operation)
+    backend_->register_memory(heap_base, heap_size);
+
+    // Step 3: Create CPU-GPU queue
+    queue_ = std::make_unique<rdma::queue>(queue_size);
+  }
+
+  ~rdma_proxy() {
+    if (running_) {
+      stop_proxy_thread();
+    }
+  }
+
+  /**
+   * @brief Start the proxy thread that processes RDMA operations
+   */
+  void start_proxy_thread() {
+    if (running_) return;
+    running_ = true;
+    proxy_thread_ = std::thread(&rdma_proxy::proxy_loop, this);
+  }
+
+  /**
+   * @brief Stop the proxy thread
+   */
+  void stop_proxy_thread() {
+    running_ = false;
+    if (proxy_thread_.joinable()) {
+      proxy_thread_.join();
+    }
+  }
+
+  /**
+   * @brief Get the queue state pointer (for passing to Triton kernels)
+   */
+  rdma::queue_state_t* get_queue_ptr() {
+    return queue_->get_queue_ptr();
+  }
+
+  /**
+   * @brief Get heap base address
+   */
+  uint64_t get_heap_base() { return heap_base_; }
+
+  /**
+   * @brief Get the network_backend (for direct RDMA operations)
+   */
+  network_backend* get_backend() { return backend_.get(); }
+
+  /**
+   * @brief Get remote heap base for a given rank
+   */
+  uint64_t get_remote_heap_base(int rank) {
+    return backend_->get_remote_heap_base(rank);
+  }
+
+  /**
+   * @brief Get rank
+   */
+  int get_rank() const { return backend_->get_rank(); }
+
+  /**
+   * @brief Get world size
+   */
+  int get_world_size() const { return backend_->get_world_size(); }
+
+  /**
+   * @brief Check if queue is empty (all work processed)
+   */
+  bool is_queue_empty() const { return queue_->is_empty(); }
+
+ private:
+  /**
+   * @brief Main proxy loop - processes RDMA operations from GPU queue
+   */
+  void proxy_loop() {
+    rdma::work_item_t item;
+
+    while (running_) {
+      // Poll for work from GPU queue
+      if (queue_->poll(item)) {
+        process_work_item(item);
+      }
+    }
+  }
+
+  /**
+   * @brief Debug helper to print work item data
+   */
+  void debug_print_work_item(const rdma::work_item_t& item) {
+    static bool debug_enabled
= (getenv("IRIS_DEBUG_DATA") != nullptr); + if (!debug_enabled || item.header.size_bytes < 4) return; + + // Extract info from work item + auto op_type = static_cast(item.header.op_type); + const char* op_name = (op_type == rdma::operation_type::PUT) ? "PUT" : + (op_type == rdma::operation_type::GET) ? "GET" : "OP"; + int dst_rank = item.header.rank; + uint64_t src_ptr = item.header.src_ptr; + uint64_t dst_ptr = item.header.dst_ptr; + size_t size = item.header.size_bytes; + void* data = (void*)src_ptr; + + static const char* dtype_env = getenv("IRIS_DTYPE"); + bool is_bf16 = (dtype_env && strcmp(dtype_env, "bfloat16") == 0); + bool is_fp16 = (dtype_env && strcmp(dtype_env, "float16") == 0); + bool is_fp32 = (!dtype_env || strcmp(dtype_env, "float32") == 0); + + if (is_bf16 || is_fp16) { + // 2-byte types + int elem_count = std::min((int)(size / 2), 10); + uint16_t* data_ptr = (uint16_t*)data; + LOG_DATA_DEBUG("[%s] rank=%d dst=%d size=%zu (bf16) src=%lx dst=%lx: first values", + op_name, backend_->get_rank(), dst_rank, size, src_ptr, dst_ptr); + for (int i = 0; i < elem_count; i++) { + uint32_t fp32_bits = ((uint32_t)data_ptr[i]) << 16; + float value = *reinterpret_cast(&fp32_bits); + fprintf(stderr, "%.1f ", value); + } + fprintf(stderr, "\n"); + } else if (is_fp32) { + // 4-byte types + int elem_count = std::min((int)(size / 4), 10); + float* float_ptr = (float*)data; + LOG_DATA_DEBUG("[%s] rank=%d dst=%d size=%zu (fp32) src=%lx dst=%lx: first values", + op_name, backend_->get_rank(), dst_rank, size, src_ptr, dst_ptr); + for (int i = 0; i < elem_count; i++) { + fprintf(stderr, "%.1f ", float_ptr[i]); + } + fprintf(stderr, "\n"); + } + } + + /** + * @brief Convert operation type to string + */ + const char* op_type_to_string(uint8_t op_type) { + switch (static_cast(op_type)) { + case rdma::operation_type::NOP: return "NOP"; + case rdma::operation_type::PUT: return "PUT"; + case rdma::operation_type::GET: return "GET"; + case rdma::operation_type::FLUSH: return "FLUSH"; + case rdma::operation_type::ATOMIC_ADD: return "ATOMIC_ADD"; + case rdma::operation_type::ATOMIC_EXCH: return "ATOMIC_EXCH"; + case rdma::operation_type::ATOMIC_CAS: return "ATOMIC_CAS"; + default: return "UNKNOWN"; + } + } + + /** + * @brief Dump raw work item bytes for debugging + */ + void dump_work_item_raw(const rdma::work_item_t& item) { + if (!iris::rdma::is_debug_data_enabled()) return; + + const uint8_t* bytes = reinterpret_cast(&item); + fprintf(stderr, "[DEBUG-DATA] Raw WorkItem (48 bytes):\n"); + fprintf(stderr, "[DEBUG-DATA] Header (32 bytes with alignas(16) padding):\n"); + fprintf(stderr, "[DEBUG-DATA] [0-7] dst_ptr: 0x%016lx\n", item.header.dst_ptr); + fprintf(stderr, "[DEBUG-DATA] [8-15] src_ptr: 0x%016lx\n", item.header.src_ptr); + fprintf(stderr, "[DEBUG-DATA] [16-19] size_bytes: %u\n", item.header.size_bytes); + fprintf(stderr, "[DEBUG-DATA] [20-21] rank: %u\n", item.header.rank); + fprintf(stderr, "[DEBUG-DATA] [22] op_type: %u (%s)\n", + item.header.op_type, op_type_to_string(item.header.op_type)); + fprintf(stderr, "[DEBUG-DATA] [23] reserved: %u\n", item.header.reserved); + fprintf(stderr, "[DEBUG-DATA] [24-31] padding (alignas)\n"); + fprintf(stderr, "[DEBUG-DATA] Atomic fields (16 bytes):\n"); + fprintf(stderr, "[DEBUG-DATA] [32-39] operand: 0x%016lx (%lu)\n", + item.atomic_operand, item.atomic_operand); + fprintf(stderr, "[DEBUG-DATA] [40-47] compare: 0x%016lx (%lu)\n", + item.atomic_compare, item.atomic_compare); + fprintf(stderr, "[DEBUG-DATA] Raw bytes: "); + for (int i = 0; i < 48; i++) { + 
fprintf(stderr, "%02x ", bytes[i]); + if ((i + 1) % 8 == 0) fprintf(stderr, " "); + } + fprintf(stderr, "\n"); + fflush(stderr); + } + + /** + * @brief Process a single work item from the queue + */ + void process_work_item(const rdma::work_item_t& item) { + auto op_type = static_cast(item.header.op_type); + int dst_rank = item.header.rank; + + // Dump raw packet for atomic operations + if (op_type == rdma::operation_type::ATOMIC_ADD || + op_type == rdma::operation_type::ATOMIC_EXCH || + op_type == rdma::operation_type::ATOMIC_CAS) { + dump_work_item_raw(item); + } + + // Get addresses from queue metadata + uint64_t src_ptr = item.header.src_ptr; // Pointer/offset in registered heap + uint64_t dst_ptr = item.header.dst_ptr; // Remote destination + size_t size = item.header.size_bytes; + + switch (op_type) { + case rdma::operation_type::PUT: { + // RDMA Write: Data is already in the registered heap at src_ptr + // No memcpy needed - just RDMA directly from heap! + void* local_addr = (void*)src_ptr; + + LOG_DEBUG("PUT: rank=%d src=%lx dst=%lx size=%zu", + dst_rank, src_ptr, dst_ptr, size); + + debug_print_work_item(item); + + int ret = backend_->rdma_write(dst_rank, local_addr, dst_ptr, size); + if (ret != 0) { + LOG_ERROR("RDMA write failed: dst=%d size=%lu", dst_rank, size); + } else { + // Poll for completion + int n = 0; + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_DEBUG("Warning: PUT completion not polled (may be OK if async)"); + } + } + + // Signal completion + queue_->pop(); + break; + } + + case rdma::operation_type::GET: { + // RDMA Read: Read from remote into local + // NOTE: WorkItem field naming is confusing for GET! 
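+        // Illustrative example (hypothetical values): a GET that pulls 4 KiB
+        // from rank 1 into this rank's heap arrives here as roughly
+        //   op_type = GET (2), rank = 1, size_bytes = 4096,
+        //   dst_ptr = remote_heap_base(1) + offset   (remote source),
+        //   src_ptr = local_heap_base     + offset   (local destination).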
+ // WorkItem.dst_ptr contains REMOTE source (translated by Triton kernel) + // WorkItem.src_ptr contains LOCAL destination + void* local_addr = (void*)src_ptr; // src_ptr field has local dest + uint64_t remote_addr = dst_ptr; // dst_ptr field has remote source + + LOG_DEBUG("GET: rank=%d remote_src=%lx local_dst=%lx size=%zu", + dst_rank, remote_addr, local_addr, size); + + int ret = backend_->rdma_read(dst_rank, local_addr, remote_addr, size); + if (ret != 0) { + LOG_ERROR("RDMA read failed: dst=%d size=%lu", dst_rank, size); + } else { + // Poll for completion + int n = 0; + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_DEBUG("Warning: GET completion not polled (may be OK if async)"); + } + } + + // Signal completion - GPU can now read from heap at src_ptr + queue_->pop(); + break; + } + + case rdma::operation_type::FLUSH: { + // Flush all pending operations for this rank + LOG_DEBUG("FLUSH: rank=%d", dst_rank); + + int total = 0; + int n; + do { + n = backend_->poll_cq(dst_rank, 16); + if (n > 0) total += n; + } while (n > 0); + + queue_->pop(); + break; + } + + case rdma::operation_type::ATOMIC_ADD: { + // Atomic add: fetch-and-add operation + // src_ptr = local result buffer, dst_ptr = remote target, atomic_operand = value to add + void* result_addr = (void*)src_ptr; + uint64_t operand = item.atomic_operand; + + LOG_DEBUG("ATOMIC_ADD: rank=%d dst=%lx operand=%lu result_buf=%lx size=%zu", + dst_rank, dst_ptr, operand, src_ptr, size); + + // Local atomics should be handled directly by the GPU kernel, not offloaded to CPU + if (dst_rank == backend_->get_rank()) { + LOG_ERROR("ERROR: Local atomic operation detected (rank %d -> rank %d). 
" + "Local atomics should be handled directly in the Triton kernel, " + "not offloaded through the RDMA queue!", + backend_->get_rank(), dst_rank); + queue_->pop(); + break; + } + + // Remote atomic - use RDMA + int ret = backend_->rdma_atomic_fetch_add(dst_rank, result_addr, dst_ptr, operand, size); + if (ret != 0) { + LOG_ERROR("RDMA atomic add failed: dst=%d size=%lu ret=%d", dst_rank, size, ret); + } else { + LOG_DEBUG("RDMA atomic add posted successfully, polling for completion..."); + // Poll for completion + int max_attempts = get_max_poll_attempts(); + int n = 0; + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) { + LOG_DEBUG("ATOMIC_ADD completed after %d attempts, completions=%d", attempt+1, n); + break; + } + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_ERROR("Warning: ATOMIC_ADD completion not polled after %d attempts!", max_attempts); + } + } + + queue_->pop(); + break; + } + + case rdma::operation_type::ATOMIC_EXCH: { + // Atomic exchange: swap operation + // src_ptr = local result buffer, dst_ptr = remote target, atomic_operand = new value + void* result_addr = (void*)src_ptr; + uint64_t new_value = item.atomic_operand; + + LOG_DEBUG("ATOMIC_EXCH: rank=%d dst=%lx new_val=%lu result_buf=%lx size=%zu", + dst_rank, dst_ptr, new_value, src_ptr, size); + + int ret = backend_->rdma_atomic_exchange(dst_rank, result_addr, dst_ptr, new_value, size); + if (ret != 0) { + LOG_ERROR("RDMA atomic exchange failed: dst=%d size=%lu", dst_rank, size); + } else { + // Poll for completion + int n = 0; + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_DEBUG("Warning: ATOMIC_EXCH completion not polled (may be OK if async)"); + } + } + + queue_->pop(); + break; + } + + case rdma::operation_type::ATOMIC_CAS: { + // Atomic compare-and-swap + // src_ptr = local result buffer, dst_ptr = remote target, + // atomic_compare = expected value, atomic_operand = new value + void* result_addr = (void*)src_ptr; + uint64_t compare = item.atomic_compare; + uint64_t swap = item.atomic_operand; + + LOG_DEBUG("ATOMIC_CAS: rank=%d dst=%lx compare=%lu swap=%lu result_buf=%lx size=%zu", + dst_rank, dst_ptr, compare, swap, src_ptr, size); + + int ret = backend_->rdma_atomic_compare_swap(dst_rank, result_addr, dst_ptr, compare, swap, size); + if (ret != 0) { + LOG_ERROR("RDMA atomic CAS failed: dst=%d size=%lu", dst_rank, size); + } else { + // Poll for completion + int n = 0; + int max_attempts = get_max_poll_attempts(); + for (int attempt = 0; attempt < max_attempts; attempt++) { + n = backend_->poll_cq(dst_rank, 1); + if (n > 0) break; + std::this_thread::sleep_for(std::chrono::microseconds(10)); + } + if (n <= 0) { + LOG_DEBUG("Warning: ATOMIC_CAS completion not polled (may be OK if async)"); + } + } + + queue_->pop(); + break; + } + + default: + LOG_ERROR("Unknown operation type: %d", item.header.op_type); + queue_->pop(); + } + } + + std::unique_ptr backend_; + std::unique_ptr queue_; + + uint64_t heap_base_; + size_t heap_size_; + + std::atomic running_; + std::thread proxy_thread_; +}; + +} // namespace iris + diff --git a/iris/experimental/iris_rdma/src/logging.hpp b/iris/experimental/iris_rdma/src/logging.hpp new file mode 100644 index 00000000..d21eaec6 --- /dev/null +++ 
b/iris/experimental/iris_rdma/src/logging.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include +#include +#include + +namespace iris { +namespace rdma { + +// Log levels +enum class log_level { + DEBUG = 0, + INFO = 1, + WARN = 2, + ERROR = 3, + NONE = 4 +}; + +// Global log level (can be set via environment variable) +inline log_level get_log_level() { + static log_level level = []() { + const char* env = std::getenv("IRIS_LOG_LEVEL"); + if (!env) return log_level::INFO; + + if (strcmp(env, "DEBUG") == 0) return log_level::DEBUG; + if (strcmp(env, "INFO") == 0) return log_level::INFO; + if (strcmp(env, "WARN") == 0) return log_level::WARN; + if (strcmp(env, "ERROR") == 0) return log_level::ERROR; + if (strcmp(env, "NONE") == 0) return log_level::NONE; + + return log_level::INFO; + }(); + return level; +} + +// Check if debug data printing is enabled (separate from log level) +inline bool is_debug_data_enabled() { + static bool enabled = (std::getenv("IRIS_DEBUG_DATA") != nullptr); + return enabled; +} + +// Internal logging function +inline void log_message(log_level level, const char* level_str, const char* fmt, ...) { + if (level < get_log_level()) return; + + // Get timestamp + time_t now = time(nullptr); + struct tm* tm_info = localtime(&now); + char time_buf[64]; + strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", tm_info); + + // Print level and timestamp + fprintf(stderr, "[%s] [%s] ", time_buf, level_str); + + // Print message + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); + fflush(stderr); +} + +// Internal logging function with rank +inline void log_message_rank(int rank, log_level level, const char* level_str, const char* fmt, ...) { + if (level < get_log_level()) return; + + // Get timestamp + time_t now = time(nullptr); + struct tm* tm_info = localtime(&now); + char time_buf[64]; + strftime(time_buf, sizeof(time_buf), "%Y-%m-%d %H:%M:%S", tm_info); + + // Print level, timestamp, and rank + fprintf(stderr, "[%s] [%s] [RANK %d] ", time_buf, level_str, rank); + + // Print message + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); + fflush(stderr); +} + +} // namespace rdma +} // namespace iris + +// Logging macros - easy to replace with real logging library later +#define LOG_DEBUG(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::DEBUG, "DEBUG", fmt, ##__VA_ARGS__) + +#define LOG_INFO(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::INFO, "INFO", fmt, ##__VA_ARGS__) + +#define LOG_WARN(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::WARN, "WARN", fmt, ##__VA_ARGS__) + +#define LOG_ERROR(fmt, ...) \ + iris::rdma::log_message(iris::rdma::log_level::ERROR, "ERROR", fmt, ##__VA_ARGS__) + +// Rank-aware logging macros +#define LOG_DEBUG_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::DEBUG, "DEBUG", fmt, ##__VA_ARGS__) + +#define LOG_INFO_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::INFO, "INFO", fmt, ##__VA_ARGS__) + +#define LOG_WARN_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::WARN, "WARN", fmt, ##__VA_ARGS__) + +#define LOG_ERROR_RANK(rank, fmt, ...) \ + iris::rdma::log_message_rank(rank, iris::rdma::log_level::ERROR, "ERROR", fmt, ##__VA_ARGS__) + +// For data debugging (separate from regular logging) +#define LOG_DATA_DEBUG(fmt, ...) 
\ + do { \ + if (iris::rdma::is_debug_data_enabled()) { \ + fprintf(stderr, "[DEBUG-DATA] " fmt "\n", ##__VA_ARGS__); \ + fflush(stderr); \ + } \ + } while (0) + diff --git a/iris/experimental/iris_rdma/src/network_backend.hpp b/iris/experimental/iris_rdma/src/network_backend.hpp new file mode 100644 index 00000000..0405665e --- /dev/null +++ b/iris/experimental/iris_rdma/src/network_backend.hpp @@ -0,0 +1,838 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ibv_utils.hpp" +#include "queue_pair.hpp" +#include "torch_bootstrap.hpp" + +// Vendor-specific headers +#ifdef HAVE_MLX5 +#include +#endif + +#ifdef HAVE_BNXT +#include +#endif + +namespace iris { + +/** + * @brief Main network backend for InfiniBand setup + * + * Handles: + * - Device detection and initialization + * - Protection domain creation + * - Queue pair creation and state transitions + * - Memory registration + * - QP connection info exchange + */ +class network_backend { + public: + /** + * @brief Constructor + * @param bootstrap PyTorch bootstrap for cross-rank communication + * @param device_name Optional device name (NULL for auto-detect) + */ + network_backend(std::shared_ptr bootstrap, + const char* device_name = nullptr) + : bootstrap_(bootstrap), + requested_dev_(device_name), + context_(nullptr), + pd_orig_(nullptr), + vendor_(rdma::nic_vendor::NONE), + port_(1), + gid_index_(0), + heap_mr_(nullptr), + heap_base_(0), + heap_size_(0), + mlx5dv_handle_(nullptr), + bnxtdv_handle_(nullptr) { + if (!bootstrap_) { + throw std::runtime_error("Bootstrap cannot be null"); + } + rank_ = bootstrap_->get_rank(); + world_size_ = bootstrap_->get_world_size(); + LOG_INFO("network_backend created: rank=%d, world_size=%d", rank_, world_size_); + } + + /** + * @brief Destructor - cleanup InfiniBand resources + */ + ~network_backend() { + LOG_DEBUG("network_backend cleanup started"); + + qps_.clear(); + + for (auto* cq : cqs_) { + if (cq) { + ibv_destroy_cq(cq); + } + } + cqs_.clear(); + + if (heap_mr_) { + ibv_dereg_mr(heap_mr_); + heap_mr_ = nullptr; + } + + if (pd_orig_) { + ibv_dealloc_pd(pd_orig_); + pd_orig_ = nullptr; + } + + if (context_) { + ibv_close_device(context_); + context_ = nullptr; + } + + if (mlx5dv_handle_) { + dlclose(mlx5dv_handle_); + mlx5dv_handle_ = nullptr; + } + + if (bnxtdv_handle_) { + dlclose(bnxtdv_handle_); + bnxtdv_handle_ = nullptr; + } + + LOG_DEBUG("NetworkBackend cleanup completed"); + } + + /** + * @brief Initialize the network (setup QPs, transition to RTS) + */ + void init() { + LOG_INFO("network_backend::init() started"); + + autodetect_dv_libs(); + open_ib_device(); + create_queues(); + exchange_qp_dest_info(); + modify_qps_reset_to_init(); + modify_qps_init_to_rtr(); + modify_qps_rtr_to_rts(); + bootstrap_->barrier(); + + LOG_INFO("network_backend::init() completed"); + } + + /** + * @brief Register memory for RDMA + * @param ptr Pointer to memory region + * @param size Size in bytes + */ + void register_memory(void* ptr, size_t size) { + LOG_INFO("Registering memory: ptr=%p, size=%zu", ptr, size); + + int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; + + heap_mr_ = ibv_reg_mr(pd_orig_, ptr, size, access); + if (heap_mr_ == nullptr) { + int err = errno; + fprintf(stderr, "[ERROR] ibv_reg_mr returned NULL for ptr=%p, size=%zu, errno=%d (%s)\n", + ptr, size, err, 
strerror(err)); + char error_msg[256]; + snprintf(error_msg, sizeof(error_msg), + "ibv_reg_mr failed with errno %d (%s) - GPUDirect RDMA may not be enabled", + err, strerror(err)); + throw std::runtime_error(error_msg); + } + + // Store local heap base + heap_base_ = reinterpret_cast(ptr); + heap_size_ = size; + + // Exchange remote keys + rkeys_.resize(world_size_); + std::vector all_rkeys(world_size_); + all_rkeys[rank_] = heap_mr_->rkey; + bootstrap_->all_gather(all_rkeys.data(), sizeof(uint32_t)); + for (int i = 0; i < world_size_; i++) { + rkeys_[i] = all_rkeys[i]; + } + + // Exchange heap base addresses (collective operation) + remote_heap_bases_.resize(world_size_); + std::vector all_heap_bases(world_size_); + all_heap_bases[rank_] = heap_base_; + bootstrap_->all_gather(all_heap_bases.data(), sizeof(uint64_t)); + for (int i = 0; i < world_size_; i++) { + remote_heap_bases_[i] = all_heap_bases[i]; + } + + // Update QPs with lkey and rkey + uint32_t lkey = heap_mr_->lkey; + for (int i = 0; i < world_size_; i++) { + if (i < qps_.size() && qps_[i]) { + qps_[i]->set_lkey(lkey); + qps_[i]->set_rkey(rkeys_[i]); + } + } + + LOG_INFO("Memory registered: lkey=%u, rkey=%u, heap_base=%p", + lkey, heap_mr_->rkey, ptr); + } + + /** + * @brief Get queue pair for destination rank + * @param dst_rank Destination rank + * @return Pointer to QueuePair object + */ + queue_pair* get_qp(int dst_rank) { + if (dst_rank >= 0 && dst_rank < qps_.size()) { + return qps_[dst_rank].get(); + } + return nullptr; + } + + /** + * @brief Get QP info for Python + * @param dst_rank Destination rank + * @return QPInfo structure + */ + rdma::qp_info_t get_qp_info(int dst_rank) { + queue_pair* qp = get_qp(dst_rank); + if (qp) { + return qp->get_info(); + } + return rdma::qp_info_t{0, 0, 0, dst_rank}; + } + + + + + /** + * @brief Get rank + */ + int get_rank() const { return rank_; } + + /** + * @brief Get world size + */ + int get_world_size() const { return world_size_; } + + /** + * @brief Get remote heap base address for a rank + * @param rank Remote rank + * @return Remote heap base address (0 if not registered) + */ + uint64_t get_remote_heap_base(int rank) const { + if (rank >= 0 && rank < remote_heap_bases_.size()) { + return remote_heap_bases_[rank]; + } + return 0; + } + + /** + * @brief Get local heap base address + * @return Local heap base address (0 if not registered) + */ + uint64_t get_heap_base() const { return heap_base_; } + + /** + * @brief Get heap size + * @return Heap size in bytes (0 if not registered) + */ + size_t get_heap_size() const { return heap_size_; } + + /** + * @brief RDMA Write operation + * @param dst_rank Destination rank + * @param local_addr Local buffer address + * @param remote_addr Remote buffer address + * @param size Size in bytes + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_write(int dst_rank, void* local_addr, uint64_t remote_addr, + size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)local_addr; + sge.length = size; + sge.lkey = qp->get_lkey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_WRITE; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = qp->get_rkey(); + + struct ibv_send_wr* bad_wr; + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); + + 
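+    // Note: ibv_post_send() only queues the work request; a return value of 0
+    // does not mean the data has reached the remote heap. Because the WR is
+    // posted with IBV_SEND_SIGNALED, callers must later reap the completion
+    // via poll_cq() to know the write has finished.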
LOG_DEBUG("RDMA Write to rank %d: local=%p remote=%lx size=%zu ret=%d", + dst_rank, local_addr, remote_addr, size, ret); + + return ret; + } + + /** + * @brief RDMA Read operation + * @param dst_rank Destination rank + * @param local_addr Local buffer address + * @param remote_addr Remote buffer address + * @param size Size in bytes + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_read(int dst_rank, void* local_addr, uint64_t remote_addr, + size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)local_addr; + sge.length = size; + sge.lkey = qp->get_lkey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_RDMA_READ; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.rdma.remote_addr = remote_addr; + wr.wr.rdma.rkey = qp->get_rkey(); + + struct ibv_send_wr* bad_wr; + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); + + LOG_DEBUG("RDMA Read from rank %d: local=%p remote=%lx size=%zu ret=%d", + dst_rank, local_addr, remote_addr, size, ret); + + return ret; + } + + /** + * @brief RDMA atomic fetch-and-add operation + * @param dst_rank Destination rank + * @param result_addr Local buffer to store the original value + * @param remote_addr Remote address to perform atomic add on + * @param add_value Value to add + * @param size Size in bytes (must be 4 or 8) + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_atomic_fetch_add(int dst_rank, void* result_addr, uint64_t remote_addr, + uint64_t add_value, size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + if (size != 4 && size != 8) { + LOG_ERROR("Atomic operations only support 4 or 8 byte sizes, got %zu", size); + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)result_addr; + sge.length = size; + sge.lkey = qp->get_lkey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.atomic.remote_addr = remote_addr; + wr.wr.atomic.rkey = qp->get_rkey(); + wr.wr.atomic.compare_add = add_value; // Value to add + + struct ibv_send_wr* bad_wr = nullptr; + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); + + LOG_DEBUG("RDMA Atomic Fetch-Add to rank %d: result=%p remote=%lx add=%lu size=%zu ret=%d", + dst_rank, result_addr, remote_addr, add_value, size, ret); + + return ret; + } + + /** + * @brief RDMA atomic exchange (swap) operation + * @param dst_rank Destination rank + * @param result_addr Local buffer to store the original value + * @param remote_addr Remote address to exchange + * @param new_value New value to write + * @param size Size in bytes (must be 4 or 8) + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_atomic_exchange(int dst_rank, void* result_addr, uint64_t remote_addr, + uint64_t new_value, size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + if (size != 4 && size != 8) { + LOG_ERROR("Atomic operations only support 4 or 8 byte sizes, got %zu", size); + return -1; + } + + // For exchange, we need a staging buffer for the new value + // ibverbs doesn't have a direct exchange, so we use CAS 
in a loop + // But for simplicity, we can use MLX5 extended atomics if available + // For now, we'll return an error and note this needs vendor-specific support + LOG_ERROR("RDMA atomic exchange not yet implemented - needs vendor-specific support"); + return -1; + } + + /** + * @brief RDMA atomic compare-and-swap operation + * @param dst_rank Destination rank + * @param result_addr Local buffer to store the original value + * @param remote_addr Remote address to perform CAS on + * @param compare_value Expected value + * @param swap_value Value to swap in if comparison succeeds + * @param size Size in bytes (must be 4 or 8) + * @param wr_id Work request ID (for completion tracking) + * @return 0 on success, non-zero on error + */ + int rdma_atomic_compare_swap(int dst_rank, void* result_addr, uint64_t remote_addr, + uint64_t compare_value, uint64_t swap_value, + size_t size, uint64_t wr_id = 0) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + if (size != 4 && size != 8) { + LOG_ERROR("Atomic operations only support 4 or 8 byte sizes, got %zu", size); + return -1; + } + + struct ibv_sge sge; + sge.addr = (uintptr_t)result_addr; + sge.length = size; + sge.lkey = qp->get_lkey(); + + struct ibv_send_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = wr_id; + wr.sg_list = &sge; + wr.num_sge = 1; + wr.opcode = IBV_WR_ATOMIC_CMP_AND_SWP; + wr.send_flags = IBV_SEND_SIGNALED; + wr.wr.atomic.remote_addr = remote_addr; + wr.wr.atomic.rkey = qp->get_rkey(); + wr.wr.atomic.compare_add = compare_value; // Expected value + wr.wr.atomic.swap = swap_value; // New value if compare succeeds + + struct ibv_send_wr* bad_wr = nullptr; + int ret = ibv_post_send(qp->get_ibv_qp(), &wr, &bad_wr); + + LOG_DEBUG("RDMA Atomic CAS to rank %d: result=%p remote=%lx compare=%lu swap=%lu size=%zu ret=%d", + dst_rank, result_addr, remote_addr, compare_value, swap_value, size, ret); + + return ret; + } + + /** + * @brief Poll completion queue for RDMA operations + * @param dst_rank Destination rank (to poll specific CQ) + * @param max_completions Maximum number of completions to poll + * @return Number of completions polled (negative on error) + */ + int poll_cq(int dst_rank, int max_completions = 1) { + queue_pair* qp = get_qp(dst_rank); + if (!qp) { + return -1; + } + + struct ibv_wc wc[16]; + int num_to_poll = (max_completions < 16) ? 
max_completions : 16; + int n = ibv_poll_cq(qp->get_ibv_cq(), num_to_poll, wc); + + if (n < 0) { + LOG_ERROR_RANK(rank_, "CQ poll error for QP to rank %d", dst_rank); + return n; + } + + // Check for errors in completions + for (int i = 0; i < n; i++) { + if (wc[i].status != IBV_WC_SUCCESS) { + LOG_ERROR_RANK(rank_, "Work completion failed: status=%d (%s) opcode=%d wr_id=%lu dst_rank=%d", + wc[i].status, ibv_wc_status_str(wc[i].status), wc[i].opcode, wc[i].wr_id, dst_rank); + return -1; + } + LOG_DEBUG_RANK(rank_, "Work completion: status=SUCCESS opcode=%d wr_id=%lu dst_rank=%d", + wc[i].opcode, wc[i].wr_id, dst_rank); + } + + // Dump CQ info and check its healthy + qp->dump_cq_info(); + + //LOG_DEBUG_RANK(rank_, "Polled %d completions from CQ for QP to rank %d", n, dst_rank); + return n; + } + + + + private: + // Bootstrap + std::shared_ptr bootstrap_; + int rank_; + int world_size_; + + // Device configuration + const char* requested_dev_; + struct ibv_context* context_; + struct ibv_pd* pd_orig_; + struct ibv_device_attr device_attr_; + rdma::nic_vendor vendor_; + + // Port configuration + struct ibv_port_attr portinfo_; + union ibv_gid gid_; + int port_; + int gid_index_; + + // Memory registration + struct ibv_mr* heap_mr_; + std::vector rkeys_; // Remote keys from all ranks + uint64_t heap_base_; // Local heap base address + size_t heap_size_; // Local heap size + std::vector remote_heap_bases_; // Heap base addresses from all ranks + + // Queue pairs + std::vector> qps_; + std::vector cqs_; + std::vector dest_info_; + + // Dynamic library handles for vendor-specific libraries + void* mlx5dv_handle_; + void* bnxtdv_handle_; + + // Setup functions (extracted from rocSHMEM) + + // Vendor-specific init + void autodetect_dv_libs() { + LOG_DEBUG("Auto-detecting vendor libraries..."); + + // Try MLX5 + if (mlx5_dv_dl_init() == 0) { + vendor_ = rdma::nic_vendor::MLX5; + LOG_INFO("Detected MLX5 vendor"); + return; + } + + // Try BNXT + if (bnxt_dv_dl_init() == 0) { + vendor_ = rdma::nic_vendor::BNXT; + LOG_INFO("Detected BNXT vendor"); + return; + } + + // Default to standard verbs + vendor_ = rdma::nic_vendor::NONE; + LOG_INFO("Using standard InfiniBand verbs"); + } + + int mlx5_dv_dl_init() { + mlx5dv_handle_ = dlopen("libmlx5.so", RTLD_NOW); + if (!mlx5dv_handle_) { + mlx5dv_handle_ = dlopen("libmlx5.so.1", RTLD_NOW); + } + + if (!mlx5dv_handle_) { + LOG_DEBUG("Could not open libmlx5.so"); + return -1; + } + + return 0; + } + + int bnxt_dv_dl_init() { + bnxtdv_handle_ = dlopen("libbnxt_re.so", RTLD_NOW); + if (!bnxtdv_handle_) { + bnxtdv_handle_ = dlopen("/usr/local/lib/libbnxt_re.so", RTLD_NOW); + } + + if (!bnxtdv_handle_) { + LOG_DEBUG("Could not open libbnxt_re.so"); + return -1; + } + + return 0; + } + + void open_ib_device() { + LOG_INFO("Opening InfiniBand device..."); + + struct ibv_device** device_list = nullptr; + struct ibv_device* device = nullptr; + int num_devices = 0; + + device_list = ibv_get_device_list(&num_devices); + CHECK_NNULL(device_list, "ibv_get_device_list"); + + if (num_devices == 0) { + throw std::runtime_error("No InfiniBand devices found"); + } + + // Select device + device = device_list[0]; // Default to first device + + if (requested_dev_) { + for (int i = 0; i < num_devices; i++) { + const char* dev_name = ibv_get_device_name(device_list[i]); + CHECK_NNULL(dev_name, "ibv_get_device_name"); + + if (strstr(dev_name, requested_dev_)) { + device = device_list[i]; + break; + } + } + } + + // Open device + context_ = ibv_open_device(device); + 
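+    // The verbs context obtained here remains valid after
+    // ibv_free_device_list() is called below; the device list itself only
+    // needs to outlive this ibv_open_device() call.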
CHECK_NNULL(context_, "ibv_open_device"); + rdma::dump_ibv_context(context_); + rdma::dump_ibv_device(context_->device); + + // Query device attributes (needed for atomic operations) + int err = ibv_query_device(context_, &device_attr_); + CHECK_ZERO(err, "ibv_query_device"); + LOG_DEBUG("Device attributes: max_qp_rd_atom=%d max_qp_init_rd_atom=%d", + device_attr_.max_qp_rd_atom, device_attr_.max_qp_init_rd_atom); + + // Allocate protection domain + pd_orig_ = ibv_alloc_pd(context_); + CHECK_NNULL(pd_orig_, "ibv_alloc_pd"); + rdma::dump_ibv_pd(pd_orig_); + + // Query port + err = ibv_query_port(context_, port_, &portinfo_); + CHECK_ZERO(err, "ibv_query_port"); + rdma::dump_ibv_port_attr(&portinfo_); + + // Select GID index + select_gid_index(); + + ibv_free_device_list(device_list); + + LOG_INFO("InfiniBand device opened: %s", + ibv_get_device_name(context_->device)); + } + + void select_gid_index() { + LOG_DEBUG("Selecting GID index..."); + + const uint8_t local_gid_prefix[2] = {0xFE, 0x80}; + int selected_gid_index = -1; + union ibv_gid selected_gid; + int err; + + int gid_tbl_len = portinfo_.gid_tbl_len; + + for (int i = 0; i < gid_tbl_len; i++) { + union ibv_gid current_gid; + err = ibv_query_gid(context_, port_, i, ¤t_gid); + if (err != 0) + continue; + + // Skip local GIDs + if (memcmp(current_gid.raw, &local_gid_prefix, 2) == 0) { + continue; + } + + // Use first non-local GID + if (selected_gid_index == -1) { + selected_gid_index = i; + selected_gid = current_gid; + break; + } + } + + if (selected_gid_index == -1) { + selected_gid_index = 0; + err = ibv_query_gid(context_, port_, 0, &selected_gid); + CHECK_ZERO(err, "ibv_query_gid"); + } + + gid_index_ = selected_gid_index; + gid_ = selected_gid; + + LOG_DEBUG("Selected GID index: %d", gid_index_); + } + + void create_queues() { + LOG_DEBUG("Creating queues..."); + + int ncqes = 64; // Number of CQ entries + int sq_length = 64; // Send queue length // TODO: FIX THAT + + // Resize vectors + dest_info_.resize(world_size_); + cqs_.resize(world_size_); + qps_.resize(world_size_); + + // Create CQs and QPs + create_cqs(ncqes); + create_qps(sq_length); + + LOG_INFO("Created %d queue pairs", world_size_); + } + + void create_cqs(int ncqes) { + LOG_DEBUG("Creating completion queues: ncqes=%d", ncqes); + + for (int i = 0; i < world_size_; i++) { + cqs_[i] = ibv_create_cq(context_, ncqes, nullptr, nullptr, 0); + CHECK_NNULL(cqs_[i], "ibv_create_cq"); + } + } + + void create_qps(int sq_length) { + LOG_DEBUG("Creating queue pairs: sq_length=%d", sq_length); + + struct ibv_qp_init_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.cap.max_send_wr = sq_length; + attr.cap.max_send_sge = 1; + attr.cap.max_inline_data = 8; + attr.sq_sig_all = 0; + attr.qp_type = IBV_QPT_RC; + + for (int i = 0; i < world_size_; i++) { + attr.send_cq = cqs_[i]; + attr.recv_cq = cqs_[i]; + + struct ibv_qp* qp = ibv_create_qp(pd_orig_, &attr); + CHECK_NNULL(qp, "ibv_create_qp"); + + qps_[i] = std::make_unique(qp, cqs_[i], i, vendor_); + } + } + + void exchange_qp_dest_info() { + LOG_DEBUG("Exchanging QP destination info..."); + + // Fill local dest info + for (int i = 0; i < world_size_; i++) { + dest_info_[i].lid = portinfo_.lid; + dest_info_[i].qpn = qps_[i]->get_qp_num(); + dest_info_[i].psn = 0; + dest_info_[i].gid = gid_; + } + + // All-gather dest info + bootstrap_->all_gather(dest_info_.data(), sizeof(rdma::qp_dest_info_t)); + + LOG_DEBUG("QP destination info exchanged"); + } + + void modify_qps_reset_to_init() { + LOG_DEBUG("Transitioning QPs: RESET -> 
INIT"); + + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = port_; + attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE | + IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; + + int attr_mask = + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS; + + for (int i = 0; i < world_size_; i++) { + int err = ibv_modify_qp(qps_[i]->get_ibv_qp(), &attr, attr_mask); + CHECK_ZERO(err, "modify_qp (RESET->INIT)"); + } + } + + void modify_qps_init_to_rtr() { + LOG_DEBUG("Transitioning QPs: INIT -> RTR"); + + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = portinfo_.active_mtu; + attr.min_rnr_timer = 12; + attr.max_dest_rd_atomic = device_attr_.max_qp_rd_atom; // Use device capability + attr.ah_attr.port_num = port_; + + if (portinfo_.link_layer == IBV_LINK_LAYER_ETHERNET) { + attr.ah_attr.grh.sgid_index = gid_index_; + attr.ah_attr.is_global = 1; + attr.ah_attr.grh.hop_limit = 1; + attr.ah_attr.sl = 1; + attr.ah_attr.grh.traffic_class = 0; + } + + int attr_mask = IBV_QP_STATE | IBV_QP_PATH_MTU | IBV_QP_RQ_PSN | + IBV_QP_DEST_QPN | IBV_QP_AV | IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER; + + for (int i = 0; i < world_size_; i++) { + attr.rq_psn = dest_info_[i].psn; + attr.dest_qp_num = dest_info_[i].qpn; + + if (portinfo_.link_layer == IBV_LINK_LAYER_ETHERNET) { + memcpy(&attr.ah_attr.grh.dgid, &dest_info_[i].gid, 16); + } else { + attr.ah_attr.dlid = dest_info_[i].lid; + } + + int err = ibv_modify_qp(qps_[i]->get_ibv_qp(), &attr, attr_mask); + CHECK_ZERO(err, "modify_qp (INIT->RTR)"); + } + } + + void modify_qps_rtr_to_rts() { + LOG_DEBUG("Transitioning QPs: RTR -> RTS"); + + struct ibv_qp_attr attr; + memset(&attr, 0, sizeof(attr)); + + attr.qp_state = IBV_QPS_RTS; + attr.timeout = 14; + attr.retry_cnt = 7; + attr.rnr_retry = 7; + attr.max_rd_atomic = device_attr_.max_qp_init_rd_atom; // Use device capability + + int attr_mask = IBV_QP_STATE | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC | + IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY; + + for (int i = 0; i < world_size_; i++) { + attr.sq_psn = dest_info_[i].psn; + + int err = ibv_modify_qp(qps_[i]->get_ibv_qp(), &attr, attr_mask); + CHECK_ZERO(err, "modify_qp (RTR->RTS)"); + } + } + +}; + +} // namespace iris diff --git a/iris/experimental/iris_rdma/src/queue.hpp b/iris/experimental/iris_rdma/src/queue.hpp new file mode 100644 index 00000000..0d3b30c2 --- /dev/null +++ b/iris/experimental/iris_rdma/src/queue.hpp @@ -0,0 +1,153 @@ +// GPU-to-CPU Queue - C++ Host Side +// Exposes queue pointer to Python/Triton + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace iris { +namespace rdma { + +// Operation types - simplified for Iris +enum class operation_type : uint8_t { + NOP = 0, + PUT = 1, // RDMA write + GET = 2, // RDMA read + FLUSH = 3, // Flush connection + ATOMIC_ADD = 4, // Atomic add + ATOMIC_EXCH = 5, // Atomic exchange + ATOMIC_CAS = 6, // Atomic compare-and-swap +}; + +// Work item structure - metadata only, no data storage +// Data is stored in the registered symmetric heap +struct alignas(16) work_item_header_t { + uint64_t dst_ptr; // Destination pointer (where to write on remote) + uint64_t src_ptr; // Source pointer (offset in local registered heap) + uint32_t size_bytes; // Size in bytes to transfer (WRITE LAST as ready flag) + uint16_t rank; // Remote rank + uint8_t op_type; // 
Operation type (see operation_type enum) + uint8_t reserved; // Reserved for future use +}; + +// Note: Completion is signaled by tail pointer advancement, not a flag +struct alignas(16) work_item_t { + work_item_header_t header; // 32 bytes (0-31, padded due to alignas(16)) + // For atomic operations: operand values + uint64_t atomic_operand; // Value to add/exchange (offset 32) + uint64_t atomic_compare; // For CAS: compare value (offset 40) + // Total size: 48 bytes +}; + +// Queue state visible to both CPU and GPU +struct queue_state_t { + work_item_t* items; // Queue buffer (pinned host memory) + uint64_t* head; // Head pointer (device memory, GPU writes) + uint64_t* tail; // Tail pointer (host memory, CPU writes, GPU reads) + uint64_t* tailCache; // Cached tail (device memory) + int32_t size; // Queue capacity +}; + +// CPU-side queue management +class queue { + public: + explicit queue(int size = 512) : size_(size) { + // Allocate pinned memory for queue_state_t struct (GPU needs to read this) + hipHostMalloc(&state_, sizeof(queue_state_t)); + + // Allocate pinned memory for queue items + hipHostMalloc(&state_->items, size * sizeof(work_item_t)); + memset(state_->items, 0, size * sizeof(work_item_t)); + + // Allocate device memory for head + hipMalloc(&state_->head, sizeof(uint64_t)); + hipMemset(state_->head, 0, sizeof(uint64_t)); + + // Allocate pinned memory for tail (CPU writes, GPU reads) + hipHostMalloc(&state_->tail, sizeof(uint64_t)); + *state_->tail = 0; + + // Allocate device memory for tail cache + hipMalloc(&state_->tailCache, sizeof(uint64_t)); + hipMemset(state_->tailCache, 0, sizeof(uint64_t)); + + state_->size = size; + } + + ~queue() { + hipHostFree(state_->items); + hipFree(state_->head); + hipHostFree(state_->tail); + hipFree(state_->tailCache); + hipHostFree(state_); + } + + // Get raw pointer to queue state for Triton + queue_state_t* get_queue_ptr() { return state_; } + + // Poll for new work item (non-blocking) + bool poll(work_item_t& item) { + uint64_t currentTail = *state_->tail; + work_item_t* ptr = &state_->items[currentTail % size_]; + + // Atomic load of size_bytes (acquire semantics) - use as ready flag + // size_bytes == 0 means slot is empty/processed + uint32_t size_bytes = + reinterpret_cast*>(&ptr->header.size_bytes)->load(std::memory_order_acquire); + + // Check if slot is ready + if (size_bytes == 0) { + return false; // Queue empty + } + + // Copy entire work item (just header now, no data array) + memcpy(&item, ptr, sizeof(work_item_t)); + + return true; + } + + // Mark work item as processed + void pop() { + uint64_t currentTail = *state_->tail; + + // Clear the size_bytes to mark as processed + state_->items[currentTail % size_].header.size_bytes = 0; + + // Advance tail with release semantics (GPU will reload this into tailCache) + uint64_t newTail = currentTail + 1; + reinterpret_cast*>(state_->tail)->store(newTail, std::memory_order_release); + } + + // Get queue statistics + uint64_t get_tail() const { return *state_->tail; } + + uint64_t get_head() const { + uint64_t h; + hipMemcpy(&h, state_->head, sizeof(uint64_t), hipMemcpyDeviceToHost); + return h; + } + + int get_size() const { return size_; } + + // Check if queue is empty (all work processed) + bool is_empty() const { + uint64_t h; + hipMemcpy(&h, state_->head, sizeof(uint64_t), hipMemcpyDeviceToHost); + return h == *state_->tail; + } + + private: + queue_state_t* state_; + int size_; +}; + +} // namespace rdma +} // namespace iris diff --git 
a/iris/experimental/iris_rdma/src/queue_pair.hpp b/iris/experimental/iris_rdma/src/queue_pair.hpp new file mode 100644 index 00000000..b3fea24a --- /dev/null +++ b/iris/experimental/iris_rdma/src/queue_pair.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include "ibv_utils.hpp" + +namespace iris { + +/** + * @brief Simplified Queue Pair wrapper for host-side operations + * + * Unlike the full rocSHMEM QueuePair, this version only maintains + * metadata needed for RDMA operations from Python/host code. + */ +class queue_pair { + public: + /** + * @brief Constructor + * @param qp InfiniBand queue pair + * @param cq InfiniBand completion queue + * @param dst_rank Destination rank for this QP + * @param vendor NIC vendor type + */ + inline queue_pair(struct ibv_qp* qp, + struct ibv_cq* cq, + int dst_rank, + rdma::nic_vendor vendor) + : qp_(qp), + cq_(cq), + dst_rank_(dst_rank), + vendor_(vendor), + lkey_(0), + rkey_(0) { + CHECK_NNULL(qp_, "QueuePair: ibv_qp"); + CHECK_NNULL(cq_, "QueuePair: ibv_cq"); + qp_num_ = qp_->qp_num; + LOG_DEBUG("queue_pair created: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + } + + /** + * @brief Destructor + */ + inline ~queue_pair() { + LOG_DEBUG("queue_pair destroyed: qp_num=%u, dst_rank=%d", qp_num_, dst_rank_); + } + + /** + * @brief Get QP number + */ + uint32_t get_qp_num() const { return qp_num_; } + + /** + * @brief Get local key for memory region + */ + uint32_t get_lkey() const { return lkey_; } + + /** + * @brief Get remote key for destination rank + */ + uint32_t get_rkey() const { return rkey_; } + + /** + * @brief Get destination rank + */ + int get_dst_rank() const { return dst_rank_; } + + /** + * @brief Set remote key (after exchange) + */ + void set_rkey(uint32_t rkey) { rkey_ = rkey; } + + /** + * @brief Set local key (from memory registration) + */ + void set_lkey(uint32_t lkey) { lkey_ = lkey; } + + /** + * @brief Get underlying ibv_qp pointer + */ + struct ibv_qp* get_ibv_qp() { return qp_; } + + /** + * @brief Get underlying ibv_cq pointer + */ + struct ibv_cq* get_ibv_cq() { return cq_; } + + /** + * @brief Get QP info for Python + */ + inline rdma::qp_info_t get_info() const { + rdma::qp_info_t info; + info.qp_num = qp_num_; + info.lkey = lkey_; + info.rkey = rkey_; + info.dst_rank = dst_rank_; + return info; + } + + + void dump_cq_info() const { + LOG_DEBUG("cq: %p", cq_); + LOG_DEBUG("handle: %u", cq_->channel); + LOG_DEBUG("cq_context: %p", cq_->cq_context); + LOG_DEBUG("context: %p", cq_->context); + LOG_DEBUG("cqe: %u", cq_->cqe); + LOG_DEBUG("comp_events_completed: %u", cq_->comp_events_completed); + LOG_DEBUG("async_events_completed: %u", cq_->async_events_completed); + + } + private: + struct ibv_qp* qp_; + struct ibv_cq* cq_; + int dst_rank_; + rdma::nic_vendor vendor_; + + uint32_t qp_num_; + uint32_t lkey_; + uint32_t rkey_; +}; + +} // namespace iris + diff --git a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp new file mode 100644 index 00000000..081336b3 --- /dev/null +++ b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
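+
+// Note on the all_gather() convention used during setup: each rank fills its
+// own slot at allData[rank * size] before the call and reads every other
+// rank's slot afterwards. A minimal sketch of a caller (the rkey exchange in
+// network_backend::register_memory follows exactly this pattern):
+//
+//   std::vector<uint32_t> all_rkeys(world_size);
+//   all_rkeys[rank] = my_rkey;                        // contribute own slot
+//   bootstrap->all_gather(all_rkeys.data(), sizeof(uint32_t));
+//   uint32_t peer_rkey = all_rkeys[peer];             // read a peer's slot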
diff --git a/iris/experimental/iris_rdma/src/torch_bootstrap.hpp b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp
new file mode 100644
index 00000000..081336b3
--- /dev/null
+++ b/iris/experimental/iris_rdma/src/torch_bootstrap.hpp
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
+
+#include <torch/torch.h>
+#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
+#include <c10/util/intrusive_ptr.h>
+#include <cstring>
+#include <stdexcept>
+#include <vector>
+#include "ibv_utils.hpp"
+
+namespace iris {
+namespace rdma {
+
+/**
+ * @brief Bootstrap implementation using PyTorch Distributed
+ *
+ * Wraps PyTorch's c10d process group to provide synchronization
+ * primitives needed for InfiniBand setup (all_gather, barrier)
+ */
+class torch_bootstrap {
+ public:
+  /**
+   * @brief Constructor
+   * @param process_group PyTorch distributed process group
+   */
+  inline explicit torch_bootstrap(c10::intrusive_ptr<c10d::ProcessGroup> process_group)
+      : process_group_(process_group) {
+    if (!process_group_) {
+      throw std::runtime_error("Process group cannot be null");
+    }
+    rank_ = process_group_->getRank();
+    world_size_ = process_group_->getSize();
+    LOG_INFO("torch_bootstrap initialized: rank=%d, world_size=%d", rank_, world_size_);
+  }
+
+  /**
+   * @brief Get rank of current process
+   */
+  int get_rank() const { return rank_; }
+
+  /**
+   * @brief Get total number of ranks
+   */
+  int get_world_size() const { return world_size_; }
+
+  /**
+   * @brief All-gather operation
+   *
+   * Gathers data from all ranks. Each rank contributes 'size' bytes
+   * starting at allData[rank * size].
+   *
+   * @param allData Buffer to hold all gathered data (world_size * size bytes)
+   * @param size Size of data contributed by each rank
+   */
+  inline void all_gather(void* allData, int size) {
+    auto cpu_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+    auto cuda_options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+
+    auto cpu_input = torch::from_blob(
+        static_cast<uint8_t*>(allData) + rank_ * size, {size}, cpu_options);
+    auto input = cpu_input.to(torch::kCUDA);
+
+    std::vector<torch::Tensor> output_tensors;
+    for (int i = 0; i < world_size_; i++) {
+      output_tensors.push_back(torch::empty({size}, cuda_options));
+    }
+
+    std::vector<std::vector<torch::Tensor>> output_tensor_lists = {output_tensors};
+    std::vector<torch::Tensor> input_tensors = {input};
+    auto work = process_group_->allgather(output_tensor_lists, input_tensors);
+    work->wait();
+
+    for (int i = 0; i < world_size_; i++) {
+      auto cpu_output = output_tensors[i].to(torch::kCPU);
+      std::memcpy(static_cast<uint8_t*>(allData) + i * size,
+                  cpu_output.data_ptr(), size);
+    }
+    LOG_DEBUG("AllGather completed: %d bytes per rank", size);
+  }
+
+  /**
+   * @brief Barrier synchronization
+   *
+   * Blocks until all ranks reach the barrier
+   */
+  inline void barrier() {
+    auto work = process_group_->barrier();
+    work->wait();
+    LOG_DEBUG("Barrier completed");
+  }
+
+  /**
+   * @brief Point-to-point send (optional, not needed for basic setup)
+   */
+  inline void send(void* data, int size, int peer, int tag) {
+    auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+    auto tensor = torch::from_blob(static_cast<uint8_t*>(data), {size}, options);
+    std::vector<torch::Tensor> tensors = {tensor};
+    auto work = process_group_->send(tensors, peer, tag);
+    work->wait();
+  }
+
+  /**
+   * @brief Point-to-point receive (optional, not needed for basic setup)
+   */
+  inline void recv(void* data, int size, int peer, int tag) {
+    auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+    auto tensor = torch::from_blob(static_cast<uint8_t*>(data), {size}, options);
+    std::vector<torch::Tensor> tensors = {tensor};
+    auto work = process_group_->recv(tensors, peer, tag);
+    work->wait();
+  }
+
+ private:
+  c10::intrusive_ptr<c10d::ProcessGroup> process_group_;
+  int rank_;
+  int world_size_;
+};
+
+}  // namespace rdma
+}  // namespace iris
+
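With the bootstrap in place, exchanging per-rank connection metadata reduces to a fixed-size all_gather: each rank writes its own qp_info_t into its slot of a flat buffer and reads everyone else's back. A sketch under the assumption that qp_info_t (from ibv_utils.hpp) is a plain, trivially copyable struct; exchange_qp_info() is an illustrative helper, not part of this patch:

#include <vector>

// Assumes torch_bootstrap and qp_info_t (above / ibv_utils.hpp) are in scope.
std::vector<iris::rdma::qp_info_t> exchange_qp_info(
    iris::rdma::torch_bootstrap& bootstrap,
    const iris::rdma::qp_info_t& local_info) {
  const int world_size = bootstrap.get_world_size();
  std::vector<iris::rdma::qp_info_t> all_info(world_size);

  // Each rank fills only its own slot; all_gather fills in the rest,
  // matching the "allData[rank * size]" contract documented above.
  all_info[bootstrap.get_rank()] = local_info;
  bootstrap.all_gather(all_info.data(),
                       static_cast<int>(sizeof(iris::rdma::qp_info_t)));

  return all_info;  // all_info[r] now holds rank r's qp_num/lkey/rkey/dst_rank
}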
diff --git a/rebuild.sh b/rebuild.sh
new file mode 100755
index 00000000..9aa6c1d9
--- /dev/null
+++ b/rebuild.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+
+pip install -e . --no-build-isolation
\ No newline at end of file
diff --git a/run.sh b/run.sh
new file mode 100755
index 00000000..0563b8b5
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+
+export IRIS_RDMA_POLL_MAX_ATTEMPTS=1000
+export IRIS_LOG_LEVEL=DEBUG
+export IRIS_DEBUG_DATA=1
+torchrun --nproc_per_node=2 examples/24_rdma_atomic_add/rdma_atomic_add.py
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 69832461..59fa1192 100644
--- a/setup.py
+++ b/setup.py
@@ -1,11 +1,108 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
 
-from setuptools import setup
+import os
+import subprocess
+import sys
+from pathlib import Path
+from setuptools import setup, Extension
+from setuptools.command.build_ext import build_ext
+
+
+class CMakeExtension(Extension):
+    """Extension that uses CMake to build"""
+    def __init__(self, name, sourcedir=""):
+        super().__init__(name, sources=[])
+        self.sourcedir = os.path.abspath(sourcedir)
+
+
+class CMakeBuild(build_ext):
+    """Custom build_ext command that runs CMake"""
+
+    def run(self):
+        # Check if CMake is available
+        try:
+            subprocess.check_output(["cmake", "--version"])
+        except OSError:
+            raise RuntimeError("CMake must be installed to build RDMA extensions")
+
+        # Build each extension
+        for ext in self.extensions:
+            self.build_extension(ext)
+
+    def build_extension(self, ext):
+        if not isinstance(ext, CMakeExtension):
+            return super().build_extension(ext)
+
+        extdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()
+
+        # CMake configuration arguments
+        cmake_args = [
+            f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
+            f"-DPYTHON_EXECUTABLE={sys.executable}",
+            "-DCMAKE_BUILD_TYPE=Release",
+        ]
+
+        # Build arguments
+        build_args = ["--config", "Release"]
+
+        # Parallel build
+        if hasattr(os, "cpu_count"):
+            build_args += [f"-j{os.cpu_count()}"]
+
+        # Create build directory
+        build_temp = Path(self.build_temp) / ext.name
+        build_temp.mkdir(parents=True, exist_ok=True)
+
+        # Run CMake
+        subprocess.check_call(
+            ["cmake", ext.sourcedir] + cmake_args,
+            cwd=build_temp
+        )
+
+        # Build
+        subprocess.check_call(
+            ["cmake", "--build", "."] + build_args,
+            cwd=build_temp
+        )
+
+
+# Check if InfiniBand libraries are available (optional RDMA support)
+def has_infiniband():
+    """Check if InfiniBand development libraries are available"""
+    try:
+        result = subprocess.run(
+            ["pkg-config", "--exists", "libibverbs"],
+            capture_output=True
+        )
+        return result.returncode == 0
+    except FileNotFoundError:
+        # pkg-config not available, try to find library directly
+        for path in ["/usr/lib", "/usr/lib64", "/usr/local/lib"]:
+            if os.path.exists(os.path.join(path, "libibverbs.so")):
+                return True
+        return False
+
+
+# Build RDMA extension if InfiniBand is available
+ext_modules = []
+if has_infiniband():
+    print("InfiniBand libraries detected - building RDMA backend")
+    rdma_ext = CMakeExtension(
+        "iris.experimental._iris_rdma_backend",
+        sourcedir="iris/experimental/iris_rdma"
+    )
+    ext_modules.append(rdma_ext)
+else:
+    print("InfiniBand libraries not found - skipping RDMA backend")
+    print("To enable RDMA support, install: libibverbs-dev (Ubuntu/Debian) or rdma-core-devel (RHEL/CentOS)")
+
 
 # This setup.py provides backward compatibility for legacy metadata fields
 # that don't map directly from pyproject.toml's modern PEP 621 format.
 setup(
     url="https://rocm.github.io/iris/",
     author="Muhammad Awad, Muhammad Osama, Brandon Potter",
+    ext_modules=ext_modules,
+    cmdclass={"build_ext": CMakeBuild} if ext_modules else {},
 )