Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@
"triton.language",
"numpy",
"iris._distributed_helpers",
"iris.hip",
]

# Napoleon settings for Google/NumPy docstring parsing
Expand Down
18 changes: 18 additions & 0 deletions docs/reference/api-hip-module.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# HIP Module API

Low-level HIP runtime integration for AMD GPU device management and memory operations.

This module provides public APIs for querying device attributes.

## Device Attributes

### get_wall_clock_rate
```{eval-rst}
.. autofunction:: iris.hip.get_wall_clock_rate
```

### get_num_xcc
```{eval-rst}
.. autofunction:: iris.hip.get_num_xcc
```

2 changes: 2 additions & 0 deletions docs/reference/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ Explore Iris APIs. The reference is broken down into focused sections to mirror
- The `Iris` class itself (constructor and helper utilities)
- Tensor-like creation methods on the `Iris` context
- Triton device-side functions for remote memory ops and atomics
- HIP runtime integration for low-level device management

Use the links below to navigate:

- [Iris Class (ctor & helpers)](api-iris-class.md)
- [Tensor Creation](api-tensor-creation.md)
- [Triton Device Functions](api-device-functions.md)
- [HIP Module](api-hip-module.md)

1 change: 1 addition & 0 deletions docs/sphinx/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ subtrees:
- file: reference/api-iris-class.md
- file: reference/api-tensor-creation.md
- file: reference/api-device-functions.md
- file: reference/api-hip-module.md
223 changes: 199 additions & 24 deletions iris/hip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

"""
HIP Runtime Integration Module

This module provides low-level HIP runtime integration for AMD GPUs,
offering Python bindings to essential HIP runtime functions through ctypes.
It enables device management, memory operations, and inter-process communication
for multi-GPU programming.

Key Features:
- Device enumeration and management
- IPC (Inter-Process Communication) memory handles
- Device attribute queries (compute units, architecture, XCC count)
- Fine-grained and coarse-grained memory allocation
- ROCm version detection

Example:
>>> import iris.hip as hip
>>> num_devices = hip.count_devices()
>>> hip.set_device(0)
>>> cu_count = hip.get_cu_count()
"""

import ctypes
import numpy as np
import sys
Expand All @@ -11,17 +33,70 @@


def hip_try(err):
    """
    Validate a HIP runtime return code, raising on failure.

    Args:
        err (int): Status code returned by a HIP runtime call
            (0 means ``hipSuccess``).

    Raises:
        RuntimeError: If ``err`` is non-zero; the message includes the
            human-readable description from ``hipGetErrorString``.

    Example:
        >>> hip_try(0)  # hipSuccess: returns silently
    """
    if err == 0:
        return
    # Ask the runtime for a human-readable description of the failure.
    hip_runtime.hipGetErrorString.restype = ctypes.c_char_p
    message = hip_runtime.hipGetErrorString(ctypes.c_int(err)).decode("utf-8")
    raise RuntimeError(f"HIP error code {err}: {message}")


class hipIpcMemHandle_t(ctypes.Structure):
    """
    Opaque HIP IPC (Inter-Process Communication) memory handle.

    Mirrors the HIP runtime's ``hipIpcMemHandle_t`` struct: 64 reserved
    bytes that uniquely identify a device memory region so that it can
    be shared with other processes.

    Attributes:
        reserved (ctypes.c_char * 64): Opaque handle payload.

    Example:
        >>> handle = hipIpcMemHandle_t()
        >>> # Use with get_ipc_handle and open_ipc_handle
    """

    # Layout must match the HIP runtime ABI exactly: 64 opaque bytes.
    _fields_ = [("reserved", ctypes.c_char * 64)]


def open_ipc_handle(ipc_handle_data, rank):
"""
Open an IPC memory handle to access shared memory from another process.

This function takes an IPC memory handle (obtained via get_ipc_handle) and
opens it to allow the current process to access the shared memory region.
The memory is opened with lazy peer access enabled.

Args:
ipc_handle_data (numpy.ndarray): A 64-element uint8 numpy array containing
the IPC handle data.
rank (int): The rank ID of the process opening the handle (used for logging/debugging).

Returns:
int: The pointer value (as Python int) to the opened shared memory.

Raises:
ValueError: If ipc_handle_data is not a 64-element uint8 numpy array.
TypeError: If ipc_handle_data is not a numpy.ndarray.
RuntimeError: If the HIP runtime call fails.

Example:
>>> # On process with rank 1, get the handle from process 0
>>> ipc_data = all_ipc_handles[0] # From distributed communication
>>> ptr = open_ipc_handle(ipc_data, rank=1)
"""
ptr = ctypes.c_void_p()
hipIpcMemLazyEnablePeerAccess = ctypes.c_uint(1)
hip_runtime.hipIpcOpenMemHandle.argtypes = [
Expand Down Expand Up @@ -55,28 +130,104 @@ def open_ipc_handle(ipc_handle_data, rank):


def get_ipc_handle(ptr, rank):
    """
    Create an IPC memory handle for a pointer so other processes can map it.

    The returned handle can be transmitted to peer processes, which open
    it (see ``open_ipc_handle``) to gain access to the memory behind
    ``ptr``.

    Args:
        ptr (ctypes.c_void_p): Pointer to the device memory region to share.
        rank (int): Rank ID of the calling process (used for logging/debugging).

    Returns:
        hipIpcMemHandle_t: Handle identifying the shared memory region.

    Raises:
        RuntimeError: If the HIP runtime call fails.

    Example:
        >>> import ctypes
        >>> heap_ptr = ctypes.c_void_p(tensor.data_ptr())
        >>> handle = get_ipc_handle(heap_ptr, rank=0)
    """
    handle = hipIpcMemHandle_t()
    status = hip_runtime.hipIpcGetMemHandle(ctypes.byref(handle), ptr)
    hip_try(status)
    return handle


def count_devices():
    """
    Return the number of HIP-capable devices (GPUs) on this system.

    Returns:
        int: Count of available HIP devices.

    Raises:
        RuntimeError: If the HIP runtime call fails.

    Example:
        >>> num_gpus = count_devices()
        >>> print(f"Found {num_gpus} GPU(s)")
    """
    count = ctypes.c_int()
    status = hip_runtime.hipGetDeviceCount(ctypes.byref(count))
    hip_try(status)
    return count.value


def set_device(gpu_id):
    """
    Make the given GPU the current HIP device for subsequent operations.

    Args:
        gpu_id (int): Zero-based device ID to activate.

    Raises:
        RuntimeError: If the HIP runtime call fails or ``gpu_id`` is invalid.

    Example:
        >>> set_device(0)  # Use GPU 0
        >>> set_device(1)  # Switch to GPU 1
    """
    status = hip_runtime.hipSetDevice(gpu_id)
    hip_try(status)


def get_device_id():
    """
    Return the ID of the currently active HIP device.

    Returns:
        int: Zero-based ID of the active device.

    Raises:
        RuntimeError: If the HIP runtime call fails.

    Example:
        >>> current_device = get_device_id()
        >>> print(f"Using GPU {current_device}")
    """
    active = ctypes.c_int()
    status = hip_runtime.hipGetDevice(ctypes.byref(active))
    hip_try(status)
    return active.value


def get_cu_count(device_id=None):
"""
Get the number of compute units (CUs) for a HIP device.

Args:
device_id (int, optional): The device ID to query. If None, uses the current device.

Returns:
int: The number of compute units on the specified device.

Raises:
RuntimeError: If the HIP runtime call fails.

Example:
>>> cu_count = get_cu_count() # Current device
>>> cu_count_gpu1 = get_cu_count(device_id=1) # Specific device
"""
if device_id is None:
device_id = get_device_id()

Expand All @@ -89,6 +240,20 @@ def get_cu_count(device_id=None):


def get_rocm_version():
"""
Get the installed ROCm version.

Returns:
tuple: A tuple of (major, minor) version numbers as integers.

Raises:
FileNotFoundError: If the ROCm version file is not found.
IndexError: If the version file format is unexpected.

Example:
>>> major, minor = get_rocm_version()
>>> print(f"ROCm version: {major}.{minor}")
"""
major, minor = -1, -1
with open("/opt/rocm/.info/version", "r") as version_file:
version = version_file.readline().strip()
Expand All @@ -98,6 +263,22 @@ def get_rocm_version():


def get_wall_clock_rate(device_id):
"""
Get the wall clock rate (GPU clock frequency) for a HIP device.

Args:
device_id (int): The device ID to query.

Returns:
int: The wall clock rate in kHz.

Raises:
RuntimeError: If the HIP runtime call fails.

Example:
>>> clock_rate = get_wall_clock_rate(0)
>>> print(f"GPU clock rate: {clock_rate} kHz")
"""
hipDeviceAttributeWallClockRate = 10017
wall_clock_rate = ctypes.c_int()
status = hip_runtime.hipDeviceGetAttribute(
Expand All @@ -107,15 +288,26 @@ def get_wall_clock_rate(device_id):
return wall_clock_rate.value


def get_arch_string(device_id=None):
if device_id is None:
device_id = get_device_id()
arch_full = torch.cuda.get_device_properties(device_id).gcnArchName
arch_name = arch_full.split(":")[0]
return arch_name
def get_num_xcc(device_id=None):
"""
Get the number of XCCs (Compute Dies) for a HIP device.

    XCC (Accelerated Compute Core) refers to the compute dies in MI300 series GPUs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot XCC Accelerated Compute Core

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated XCC acronym definition to "Accelerated Compute Core" in the docstring. Commit: a36e9a1

For ROCm versions before 7.0, returns a default value of 8.

def get_num_xcc(device_id=None):
Args:
device_id (int, optional): The device ID to query. If None, uses the current device.

Returns:
int: The number of XCCs on the device.

Raises:
RuntimeError: If the HIP runtime call fails.

Example:
>>> xcc_count = get_num_xcc()
>>> print(f"Number of XCCs: {xcc_count}")
"""
if device_id is None:
device_id = get_device_id()
rocm_major, _ = get_rocm_version()
Expand All @@ -125,20 +317,3 @@ def get_num_xcc(device_id=None):
xcc_count = ctypes.c_int()
hip_try(hip_runtime.hipDeviceGetAttribute(ctypes.byref(xcc_count), hipDeviceAttributeNumberOfXccs, device_id))
return xcc_count.value


def malloc_fine_grained(size):
hipDeviceMallocFinegrained = 0x1
ptr = ctypes.c_void_p()
hip_try(hip_runtime.hipExtMallocWithFlags(ctypes.byref(ptr), size, hipDeviceMallocFinegrained))
return ptr


def hip_malloc(size):
ptr = ctypes.c_void_p()
hip_try(hip_runtime.hipMalloc(ctypes.byref(ptr), size))
return ptr


def hip_free(ptr):
hip_try(hip_runtime.hipFree(ptr))