
Commit 602adfe

Upd
1 parent b26c69d commit 602adfe

File tree

4 files changed: +66 additions, -115 deletions

flashinfer/comm/mnnvl.py

Lines changed: 24 additions & 6 deletions
@@ -155,6 +155,9 @@ def Get_size(self) -> int: ...
     @abstractmethod
     def allgather(self, data: int) -> List[int]: ...
 
+    @abstractmethod
+    def barrier(self) -> None: ...
+
     @abstractmethod
     def Split(self, color: int, key: int) -> "CommBackend": ...
 
@@ -209,6 +212,9 @@ def Get_size(self) -> int:
     def allgather(self, data: int) -> List[int]:
         return self._mpicomm.allgather(data)
 
+    def barrier(self):
+        self._mpicomm.Barrier()
+
     def Split(self, color: int, key: int) -> CommBackend:
         self._mpicomm = self._mpicomm.Split(color, key)
         return MPIBackend()  # Returns new adapter
@@ -547,14 +553,15 @@ def supports_mnnvl() -> bool:
 
 class McastDeviceMemory:
     """Python port of McastDeviceMemory from TensorRT-LLM"""
+
     def __init__(
         self,
         buf_size: int,
         group_size: int,
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
-        comm: Optional[CommBackend] = None,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
 
@@ -631,7 +638,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )
 
-            self._alloc_mn_mcast_mem(buf_size, comm)
+            self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +760,9 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()):
+    def _alloc_mn_mcast_mem(
+        self, buf_size: int, comm_backend_for_handle_transfer: Any = None
+    ):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -766,7 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any=MpiComm()):
             )
         except Exception as e:
             print(f"Error checking CUDA context: {e}")
-
+        if comm_backend_for_handle_transfer is None:
+            comm = MpiComm()
+        else:
+            comm = comm_backend_for_handle_transfer
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
 
@@ -966,7 +978,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
-        comm: Optional[CommBackend] = None,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -977,9 +989,15 @@ def __init__(
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation
             mn_nvlink: Flag indicating if multi-node NVLink is used
+            comm_backend_for_handle_transfer: Communication backend for handle transfer
         """
        self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink, comm
+            buf_size,
+            group_size,
+            group_rank,
+            device.index,
+            mn_nvlink,
+            comm_backend_for_handle_transfer,
         )
         self.buf_size = buf_size
         self.local_device = device
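
Note (not part of the commit): the core of this file's diff is the new abstract barrier() on CommBackend plus the comm -> comm_backend_for_handle_transfer rename with an MpiComm() fallback. The sketch below shows the kind of backend the new hook expects when something other than MPI is supplied. It is illustrative only: it assumes torch.distributed is already initialized (e.g. via torchrun) and that CommBackend's abstract surface is exactly the MPI-style methods visible in the hunks above (Get_rank is an assumption).

from typing import List

import torch.distributed as dist

from flashinfer.comm.mnnvl import CommBackend


class TorchDistBackend(CommBackend):
    """Hypothetical adapter mapping the CommBackend hooks onto torch.distributed."""

    def Get_rank(self) -> int:  # assumed abstract method, mirroring the MPI naming
        return dist.get_rank()

    def Get_size(self) -> int:
        return dist.get_world_size()

    def allgather(self, data: int) -> List[int]:
        out: List[int] = [0] * dist.get_world_size()
        dist.all_gather_object(out, data)  # gather one Python int from every rank
        return out

    def barrier(self) -> None:
        dist.barrier()  # the hook this commit adds to the abstract interface

    def Split(self, color: int, key: int) -> "CommBackend":
        raise NotImplementedError("communicator splitting is out of scope for this sketch")

An instance of such a backend could be passed as comm_backend_for_handle_transfer to McastDeviceMemory or McastGpuBuffer; when the argument is left as None, _alloc_mn_mcast_mem falls back to MpiComm() as before.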

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 9 additions & 8 deletions
@@ -15,7 +15,7 @@
 
 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
-from .mnnvl import (McastGPUBuffer, CommBackend)
+from .mnnvl import McastGPUBuffer, CommBackend
 
 
 def mpi_barrier():
@@ -122,9 +122,10 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype,
+    mapping: Mapping,
+    dtype: torch.dtype,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     buffer_size_in_bytes: Optional[int] = None,
-    comm: Optional[CommBackend] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -140,8 +141,8 @@ def get_allreduce_mnnvl_workspace(
     Args:
         mapping: Tensor parallel mapping configuration containing rank info
         dtype: Data type of the tensors being reduced
-        buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
         comm: Optional communication backend for multi-node synchronization
+        buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
 
     Returns:
         Tuple containing:
@@ -170,18 +171,18 @@ def get_allreduce_mnnvl_workspace(
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
-        comm=comm,
+        comm_backend_for_handle_transfer=comm_backend_for_handle_transfer,
     )
 
     # Initialize the unicast buffer with -0.0
     mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)
 
     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    if comm:
-        comm.barrier()
-    else:
+    if comm_backend_for_handle_transfer is None:
         mpi_barrier()
+    else:
+        comm_backend_for_handle_transfer.barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
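
For callers, the visible effect of this file's changes is the renamed keyword and the barrier dispatch after torch.cuda.synchronize(). A hedged usage sketch follows (not part of the commit); the Mapping keyword names and the mpirun-style launch are assumptions borrowed from the test setup rather than anything this diff defines.

import torch
from mpi4py import MPI

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping

rank = MPI.COMM_WORLD.Get_rank()
world_size = MPI.COMM_WORLD.Get_size()
# Keyword names below are assumed from the test setup, not defined by this diff.
mapping = Mapping(
    world_size=world_size, rank=rank, gpus_per_node=world_size, tp_size=world_size
)

# Default path: handle transfer and the trailing CPU barrier go through MpiComm()/mpi_barrier().
mcast_buffer, buffer_flags, max_num_elements = trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
    mapping, torch.bfloat16
)

# With the renamed keyword, any CommBackend implementing barrier() can replace MPI
# (my_backend is a placeholder for such an object):
# trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
#     mapping, torch.bfloat16, comm_backend_for_handle_transfer=my_backend
# )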

tests/comm/test_trtllm_mnnvl_allreduce.py

Lines changed: 8 additions & 3 deletions
@@ -1,12 +1,13 @@
 # Check torch version:
-from typing import Tuple
+from typing import Tuple, Optional
 
 import pytest
 import torch
 from mpi4py import MPI  # Added MPI import
 
 import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
 from flashinfer.comm.mapping import Mapping
+from flashinfer.comm.mnnvl import CommBackend, MpiComm
 
 # Use flashinfer.norm.rmsnorm as reference implementation.
 from flashinfer.norm import rmsnorm
@@ -28,6 +29,7 @@ def row_linear_residual_norm_fusion_forward(
     unicast_ptr: int,
     max_num_elements_mnnvl: int,
     buffer_flags_mnnvl: torch.Tensor,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
 ):
     x = x.cuda()
     residual = residual.cuda()
@@ -36,8 +38,11 @@ def row_linear_residual_norm_fusion_forward(
 
     tensor_parallel_size = mapping.tp_size
     tensor_parallel_rank = mapping.tp_rank
-
-    MPI.COMM_WORLD.barrier()
+    if comm_backend_for_handle_transfer is None:
+        comm = MpiComm()
+    else:
+        comm = comm_backend_for_handle_transfer
+    comm.barrier()
 
     def func(
         input,
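
The test helper now resolves its barrier through the same fallback pattern as the library code: MpiComm() when no backend is supplied, otherwise the caller's backend. A hypothetical stub such as the one below could drive the non-default branch in a single-process smoke run; it is not part of this diff, and the Get_rank method is an assumption about the abstract base.

from typing import List

from flashinfer.comm.mnnvl import CommBackend


class NullCommBackend(CommBackend):
    """Hypothetical do-nothing backend for single-process runs of the helper."""

    def Get_rank(self) -> int:  # assumed abstract method
        return 0

    def Get_size(self) -> int:
        return 1

    def allgather(self, data: int) -> List[int]:
        return [data]  # a single "rank" only sees its own value

    def barrier(self) -> None:
        pass  # nothing to synchronize within one process

    def Split(self, color: int, key: int) -> "CommBackend":
        return self

Passing comm_backend_for_handle_transfer=NullCommBackend() makes the helper call comm.barrier() on the stub instead of MPI.COMM_WORLD, keeping the synchronization point in place without a hard mpi4py dependency.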
