Commit b26c69d

Add custom communicator for trtllm_mnnvl_ar
Upd
1 parent ba8f3ed commit b26c69d

File tree

3 files changed: +351 -10 lines changed

flashinfer/comm/mnnvl.py

Lines changed: 5 additions & 7 deletions
@@ -547,14 +547,14 @@ def supports_mnnvl() -> bool:
 
 class McastDeviceMemory:
     """Python port of McastDeviceMemory from TensorRT-LLM"""
-
     def __init__(
         self,
         buf_size: int,
         group_size: int,
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
 
@@ -631,7 +631,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )
 
-            self._alloc_mn_mcast_mem(buf_size)
+            self._alloc_mn_mcast_mem(buf_size, comm)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +753,7 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int):
+    def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -767,9 +767,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         except Exception as e:
             print(f"Error checking CUDA context: {e}")
 
-        # Get MPI communicator
-        comm = MpiComm()
-
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
 
@@ -969,6 +966,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -981,7 +979,7 @@ def __init__(
             mn_nvlink: Flag indicating if multi-node NVLink is used
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size, group_size, group_rank, device.index, mn_nvlink, comm
        )
         self.buf_size = buf_size
         self.local_device = device
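
The `CommBackend` interface itself lives in the parts of this commit that are not reproduced above; these hunks only show `comm` being threaded through in place of the implicit `MpiComm()`, and the only method the shown code calls is `barrier()` (in `get_allreduce_mnnvl_workspace` below). As a minimal sketch of what a custom backend might look like, the following adapter built on `torch.distributed` is illustrative only: the class name and the `Get_rank`/`Get_size`/`allgather` methods are assumptions that mirror the MPI communicator being replaced, not part of this diff.

# Hypothetical sketch of a communicator usable as `comm`; only barrier()
# is exercised by the hunks shown in this commit. The other methods mimic
# the MPI communicator it replaces and are assumptions for illustration.
from typing import Any, List, Optional

import torch.distributed as dist


class TorchDistCommBackend:
    """Adapter that synchronizes ranks via torch.distributed instead of MPI."""

    def __init__(self, group: Optional[dist.ProcessGroup] = None):
        # None selects the default (world) process group.
        self.group = group

    def Get_rank(self) -> int:  # assumed, mpi4py-style naming
        return dist.get_rank(self.group)

    def Get_size(self) -> int:  # assumed, mpi4py-style naming
        return dist.get_world_size(self.group)

    def allgather(self, obj: Any) -> List[Any]:  # assumed: exchange handles across ranks
        out: List[Any] = [None] * dist.get_world_size(self.group)
        dist.all_gather_object(out, obj, group=self.group)
        return out

    def barrier(self) -> None:  # the one method this diff calls directly
        dist.barrier(group=self.group)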

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 10 additions & 3 deletions
@@ -15,7 +15,7 @@
 
 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
-from .mnnvl import McastGPUBuffer
+from .mnnvl import (McastGPUBuffer, CommBackend)
 
 
 def mpi_barrier():
@@ -122,7 +122,9 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
+    mapping: Mapping, dtype: torch.dtype,
+    buffer_size_in_bytes: Optional[int] = None,
+    comm: Optional[CommBackend] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -139,6 +141,7 @@ def get_allreduce_mnnvl_workspace(
         mapping: Tensor parallel mapping configuration containing rank info
         dtype: Data type of the tensors being reduced
         buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
+        comm: Optional communication backend for multi-node synchronization
 
     Returns:
         Tuple containing:
@@ -167,14 +170,18 @@ def get_allreduce_mnnvl_workspace(
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
+        comm=comm,
     )
 
     # Initialize the unicast buffer with -0.0
     mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)
 
     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    mpi_barrier()
+    if comm:
+        comm.barrier()
+    else:
+        mpi_barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
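
With `comm` threaded through, callers can drive the workspace setup with their own communicator; omitting it keeps the previous `MpiComm()` / `mpi_barrier()` behavior. The following usage sketch assumes the hypothetical `TorchDistCommBackend` from the previous example; the `Mapping` import path, its constructor arguments, and the return-value names are also assumptions, since none of them appear in this diff.

# Sketch only: module paths and Mapping arguments are assumed, and
# TorchDistCommBackend is the hypothetical adapter sketched above.
import torch
import torch.distributed as dist

from flashinfer.comm.mapping import Mapping  # assumed import path
from flashinfer.comm.trtllm_mnnvl_ar import get_allreduce_mnnvl_workspace

# Launched via torchrun so MASTER_ADDR/MASTER_PORT etc. are already set;
# a CPU backend is enough for the handshakes and barriers.
dist.init_process_group(backend="gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

mapping = Mapping(  # assumed constructor signature
    world_size=world_size, rank=rank, gpus_per_node=8, tp_size=world_size
)

# Per the annotation Tuple[McastGPUBuffer, torch.Tensor, int]: the multicast
# buffer, the buffer-state flags tensor, and an int capacity. Variable names
# here are descriptive choices, not taken from the diff.
mcast_buffer, buffer_flags, max_num_elements = get_allreduce_mnnvl_workspace(
    mapping,
    dtype=torch.bfloat16,
    comm=TorchDistCommBackend(),  # new in this commit; falls back to MPI when omitted
)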
