Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions flashinfer/comm/mnnvl.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def Get_size(self) -> int: ...
@abstractmethod
def allgather(self, data: int) -> List[int]: ...

@abstractmethod
def barrier(self) -> None: ...

@abstractmethod
def Split(self, color: int, key: int) -> "CommBackend": ...

Expand Down Expand Up @@ -209,6 +212,9 @@ def Get_size(self) -> int:
def allgather(self, data: int) -> List[int]:
return self._mpicomm.allgather(data)

def barrier(self):
self._mpicomm.Barrier()

def Split(self, color: int, key: int) -> CommBackend:
self._mpicomm = self._mpicomm.Split(color, key)
return MPIBackend() # Returns new adapter
Expand Down Expand Up @@ -555,6 +561,7 @@ def __init__(
group_rank: int,
device_idx: int,
is_multi_node: bool = True,
comm_backend_for_handle_transfer: Optional[CommBackend] = None,
):
cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))

Expand Down Expand Up @@ -631,7 +638,7 @@ def __init__(
"[McastDeviceMemory] Device does not support fabric handle."
)

self._alloc_mn_mcast_mem(buf_size)
self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
else:
# For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
raise NotImplementedError("Single-node NVLS allocation not implemented yet")
Expand Down Expand Up @@ -753,7 +760,9 @@ def get_world_size(self) -> int:
"""Get the total number of devices in the group"""
return self.group_size

def _alloc_mn_mcast_mem(self, buf_size: int):
def _alloc_mn_mcast_mem(
self, buf_size: int, comm_backend_for_handle_transfer: Any = None
):
"""Allocate multi-node multicast memory using MNNVL"""

# Verify CUDA context
Expand All @@ -766,10 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
)
except Exception as e:
print(f"Error checking CUDA context: {e}")

# Get MPI communicator
comm = MpiComm()

if comm_backend_for_handle_transfer is None:
comm = MpiComm()
else:
comm = comm_backend_for_handle_transfer
# Set up allocation properties
handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC

Expand Down Expand Up @@ -969,6 +978,7 @@ def __init__(
group_rank: int,
device: torch.device,
mn_nvlink: bool = True,
comm_backend_for_handle_transfer: Optional[CommBackend] = None,
):
"""
Constructor for McastGpuBuffer.
Expand All @@ -979,9 +989,15 @@ def __init__(
group_rank: The rank of the local process within the group
device: The CUDA device for buffer allocation
mn_nvlink: Flag indicating if multi-node NVLink is used
comm_backend_for_handle_transfer: Communication backend for handle transfer
"""
self.mcast_device_memory = McastDeviceMemory(
buf_size, group_size, group_rank, device.index, mn_nvlink
buf_size,
group_size,
group_rank,
device.index,
mn_nvlink,
comm_backend_for_handle_transfer,
)
self.buf_size = buf_size
self.local_device = device
Expand Down
14 changes: 11 additions & 3 deletions flashinfer/comm/trtllm_mnnvl_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..jit import gen_trtllm_mnnvl_comm_module
from ..utils import register_custom_op
from .mnnvl import McastGPUBuffer
from .mnnvl import McastGPUBuffer, CommBackend


def mpi_barrier():
Expand Down Expand Up @@ -122,7 +122,10 @@ def trtllm_mnnvl_rmsnorm(


def get_allreduce_mnnvl_workspace(
mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
mapping: Mapping,
dtype: torch.dtype,
comm_backend_for_handle_transfer: Optional[CommBackend] = None,
buffer_size_in_bytes: Optional[int] = None,
) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
"""Get workspace buffers needed for multi-node NVLink all-reduce operation.

Expand All @@ -138,6 +141,7 @@ def get_allreduce_mnnvl_workspace(
Args:
mapping: Tensor parallel mapping configuration containing rank info
dtype: Data type of the tensors being reduced
comm: Optional communication backend for multi-node synchronization
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Fix docstring parameter name mismatch.

The docstring references comm: but the actual parameter is named comm_backend_for_handle_transfer. This inconsistency may confuse users and tools that parse docstrings.

Apply this diff:

-        comm: Optional communication backend for multi-node synchronization
+        comm_backend_for_handle_transfer: Optional communication backend for multi-node synchronization
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
comm: Optional communication backend for multi-node synchronization
comm_backend_for_handle_transfer: Optional communication backend for multi-node synchronization
πŸ€– Prompt for AI Agents
In flashinfer/comm/trtllm_mnnvl_ar.py around line 144, the docstring refers to a
parameter named "comm:" while the actual function parameter is
"comm_backend_for_handle_transfer"; update the docstring to use the exact
parameter name "comm_backend_for_handle_transfer" (and adjust its short
description if needed) so the parameter list matches the function signature and
docstring parsers/tools can correctly map the description to the parameter.

buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens

Returns:
Expand Down Expand Up @@ -167,14 +171,18 @@ def get_allreduce_mnnvl_workspace(
mapping.tp_rank,
torch.device("cuda", mapping.local_rank),
mapping.is_multi_node() or force_mn,
comm_backend_for_handle_transfer=comm_backend_for_handle_transfer,
)

# Initialize the unicast buffer with -0.0
mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)

# CPU barrier since we assume this should not be called in cuda graph
torch.cuda.synchronize()
mpi_barrier()
if comm_backend_for_handle_transfer is None:
mpi_barrier()
else:
comm_backend_for_handle_transfer.barrier()

# This is a buffer to maintain the state of this allreduce Op
# [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
Expand Down
11 changes: 8 additions & 3 deletions tests/comm/test_trtllm_mnnvl_allreduce.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Check torch version:
from typing import Tuple
from typing import Tuple, Optional

import pytest
import torch
from mpi4py import MPI # Added MPI import

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import CommBackend, MpiComm

# Use flashinfer.norm.rmsnorm as reference implementation.
from flashinfer.norm import rmsnorm
Expand All @@ -28,6 +29,7 @@ def row_linear_residual_norm_fusion_forward(
unicast_ptr: int,
max_num_elements_mnnvl: int,
buffer_flags_mnnvl: torch.Tensor,
comm_backend_for_handle_transfer: Optional[CommBackend] = None,
):
x = x.cuda()
residual = residual.cuda()
Expand All @@ -36,8 +38,11 @@ def row_linear_residual_norm_fusion_forward(

tensor_parallel_size = mapping.tp_size
tensor_parallel_rank = mapping.tp_rank

MPI.COMM_WORLD.barrier()
if comm_backend_for_handle_transfer is None:
comm = MpiComm()
else:
comm = comm_backend_for_handle_transfer
comm.barrier()

def func(
input,
Expand Down
Loading