Commit 3175f33

Add custom communicator for trtllm_mnnvl_ar
1 parent f25929f commit 3175f33

File tree

3 files changed (+360 −8 lines changed)

flashinfer/comm/mnnvl.py

Lines changed: 16 additions & 5 deletions
@@ -547,14 +547,15 @@ def supports_mnnvl() -> bool:
 
 class McastDeviceMemory:
     """Python port of McastDeviceMemory from TensorRT-LLM"""
-
+    # config: Optional[MnnvlConfig] = None
     def __init__(
         self,
         buf_size: int,
         group_size: int,
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
 
@@ -631,7 +632,7 @@ def __init__(
                 "[McastDeviceMemory] Device does not support fabric handle."
             )
 
-            self._alloc_mn_mcast_mem(buf_size)
+            self._alloc_mn_mcast_mem(buf_size, comm)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -649,6 +650,7 @@ def __init__(
         self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads)
         self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs)
 
+
     def __del__(self):
         """Destructor - cleanup allocated memory"""
 
@@ -753,7 +755,7 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int):
+    def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any = MpiComm()):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -768,7 +770,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             print(f"Error checking CUDA context: {e}")
 
         # Get MPI communicator
-        comm = MpiComm()
+        # comm = MpiComm()
+        # comm = McastDeviceMemory.get_comm()
+        # if config:
+        #     comm = config.comm_backend
+        # else:
+        #     comm = MpiComm()
 
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
@@ -831,6 +838,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         )
 
         # All-gather fabric handles
+        print(my_fabric_handle.data)
+        print(type(my_fabric_handle.data))
+        # all_fabric_handles = [my_fabric_handle.data] * 4
         all_fabric_handles = comm.allgather(my_fabric_handle.data)
         cuda.cuCtxSynchronize()
 
@@ -969,6 +979,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -981,7 +992,7 @@ def __init__(
             mn_nvlink: Flag indicating if multi-node NVLink is used
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size, group_size, group_rank, device.index, mn_nvlink, comm
         )
         self.buf_size = buf_size
         self.local_device = device
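
The hunks above only exercise two methods on the new comm object: allgather, used to exchange the CU_MEM_HANDLE_TYPE_FABRIC handles across ranks, and (in trtllm_mnnvl_ar.py below) barrier. The CommBackend type referenced in the annotations is not defined in this diff, so the following is only a sketch of a duck-typed backend that would satisfy those two calls; the class name and the torch.distributed-based implementation are assumptions, not part of the commit.

# Illustrative sketch only: a minimal communicator exposing the calls this
# commit makes on `comm`. The real CommBackend interface in
# flashinfer/comm/mnnvl.py may require more than this.
from typing import Any, List, Optional

import torch.distributed as dist


class TorchDistCommBackend:
    """Hypothetical backend built on an already-initialized torch.distributed group."""

    def __init__(self, group: Optional[dist.ProcessGroup] = None):
        self.group = group  # None -> default world group

    def allgather(self, obj: Any) -> List[Any]:
        # Collect one picklable object (e.g. a fabric handle) from every rank.
        gathered: List[Any] = [None] * dist.get_world_size(self.group)
        dist.all_gather_object(gathered, obj, group=self.group)
        return gathered

    def barrier(self) -> None:
        # CPU-side barrier across all ranks.
        dist.barrier(group=self.group)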

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 8 additions & 3 deletions
@@ -15,7 +15,7 @@
 
 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
-from .mnnvl import McastGPUBuffer
+from .mnnvl import (McastGPUBuffer, CommBackend)
 
 
 def mpi_barrier():
@@ -122,7 +122,8 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype
+    mapping: Mapping, dtype: torch.dtype,
+    comm: Optional[CommBackend] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -164,14 +165,18 @@ def get_allreduce_mnnvl_workspace(
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
+        comm=comm,
     )
 
     # Initialize the unicast buffer with -0.0
     mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)
 
     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    mpi_barrier()
+    if comm:
+        comm.barrier()
+    else:
+        mpi_barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
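
End to end, the new keyword lets get_allreduce_mnnvl_workspace run without MPI. A hypothetical call site is sketched below; the Mapping constructor arguments and the flashinfer.comm.mapping import path are assumptions (check the actual Mapping definition), and TorchDistCommBackend is the sketch from the previous section.

# Hypothetical wiring of a custom communicator into the workspace setup.
import torch
import torch.distributed as dist

from flashinfer.comm.mapping import Mapping  # import path assumed
from flashinfer.comm.trtllm_mnnvl_ar import get_allreduce_mnnvl_workspace

dist.init_process_group("nccl")  # one process per GPU
comm = TorchDistCommBackend()    # sketch class from above

# Mapping arguments are illustrative; consult the real Mapping signature.
mapping = Mapping(
    world_size=dist.get_world_size(),
    rank=dist.get_rank(),
    gpus_per_node=torch.cuda.device_count(),
    tp_size=dist.get_world_size(),
)

# comm=None keeps the original MPI path (MpiComm + mpi_barrier); passing a
# communicator routes the fabric-handle allgather and the final CPU barrier
# through comm.allgather / comm.barrier instead.
mcast_buffer, buffer_flags, max_num_elements = get_allreduce_mnnvl_workspace(
    mapping, torch.bfloat16, comm=comm
)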
