Commit d56be0d

Add custom communicator for trtllm_mnnvl_ar (#2056)
## 📌 Description

Add an optional custom communicator for `trtllm_mnnvl_ar`: multi-node memory and buffer allocation can take a caller-provided communication backend for handle transfer, and multi-node synchronization uses that backend's barrier when one is supplied.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Added an optional communication-backend parameter for multi-node memory and buffer allocation, allowing a provided communicator to be used for handle transfer.
* **Bug Fixes / Reliability**
  * Multi-node synchronization now uses the provided communicator's barrier when available, preserving the previous behavior otherwise.
* **Tests**
  * Added end-to-end tests covering custom communication backends and multi-node all-reduce synchronization.
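
As a quick illustration of the synchronization behavior described in the summary, the sketch below shows the dispatch pattern the changes introduce. It is illustrative only, not code from the commit: `CommBackend` and `MpiComm` come from `flashinfer.comm.mnnvl` (see the diffs below), while the wrapper function and its name are hypothetical.

```python
from typing import Optional

from flashinfer.comm.mnnvl import CommBackend, MpiComm


def sync_ranks(comm_backend_for_handle_transfer: Optional[CommBackend] = None) -> None:
    """Hypothetical helper mirroring the new barrier dispatch."""
    if comm_backend_for_handle_transfer is None:
        # No custom communicator given: fall back to the MPI path (previous behavior).
        MpiComm().barrier()
    else:
        # Custom communicator given: use its barrier for multi-node synchronization.
        comm_backend_for_handle_transfer.barrier()
```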
1 parent 2439a41 commit d56be0d

File tree: 4 files changed (+305, -13 lines)


flashinfer/comm/mnnvl.py

Lines changed: 23 additions & 7 deletions
@@ -155,6 +155,9 @@ def Get_size(self) -> int: ...
     @abstractmethod
     def allgather(self, data: int) -> List[int]: ...
 
+    @abstractmethod
+    def barrier(self) -> None: ...
+
     @abstractmethod
     def Split(self, color: int, key: int) -> "CommBackend": ...
 
@@ -209,6 +212,9 @@ def Get_size(self) -> int:
     def allgather(self, data: int) -> List[int]:
         return self._mpicomm.allgather(data)
 
+    def barrier(self):
+        self._mpicomm.Barrier()
+
     def Split(self, color: int, key: int) -> CommBackend:
         self._mpicomm = self._mpicomm.Split(color, key)
         return MPIBackend()  # Returns new adapter
@@ -555,6 +561,7 @@ def __init__(
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
 
@@ -631,7 +638,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )
 
-            self._alloc_mn_mcast_mem(buf_size)
+            self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +760,9 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int):
+    def _alloc_mn_mcast_mem(
+        self, buf_size: int, comm_backend_for_handle_transfer: Any = None
+    ):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -766,10 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             )
         except Exception as e:
             print(f"Error checking CUDA context: {e}")
-
-        # Get MPI communicator
-        comm = MpiComm()
-
+        if comm_backend_for_handle_transfer is None:
+            comm = MpiComm()
+        else:
+            comm = comm_backend_for_handle_transfer
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
 
@@ -969,6 +978,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -979,9 +989,15 @@ def __init__(
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation
             mn_nvlink: Flag indicating if multi-node NVLink is used
+            comm_backend_for_handle_transfer: Communication backend for handle transfer
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size,
+            group_size,
+            group_rank,
+            device.index,
+            mn_nvlink,
+            comm_backend_for_handle_transfer,
         )
         self.buf_size = buf_size
         self.local_device = device
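
The new abstract `barrier()` completes the `CommBackend` interface used by `McastDeviceMemory`. As an illustration of what a non-MPI communicator could look like, below is a hypothetical sketch built on `torch.distributed`; it is not part of this commit, `TorchDistBackend` is an invented name, and it assumes the ABC's abstract methods are the ones visible in this diff plus `Get_rank`.

```python
from typing import List, Optional

import torch.distributed as dist

from flashinfer.comm.mnnvl import CommBackend


class TorchDistBackend(CommBackend):
    """Hypothetical CommBackend backed by a torch.distributed process group."""

    def __init__(self, group: Optional[dist.ProcessGroup] = None):
        # None selects the default (WORLD) process group.
        self._group = group

    def Get_rank(self) -> int:
        return dist.get_rank(self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(self._group)

    def allgather(self, data: int) -> List[int]:
        out = [None for _ in range(self.Get_size())]
        dist.all_gather_object(out, data, group=self._group)
        return out

    def barrier(self) -> None:
        dist.barrier(group=self._group)

    def Split(self, color: int, key: int) -> "CommBackend":
        # torch.distributed has no direct MPI_Comm_split equivalent; a full
        # implementation would build a subgroup with dist.new_group() from the
        # ranks sharing the same color. Left unimplemented in this sketch.
        raise NotImplementedError
```

In the modified `_alloc_mn_mcast_mem`, the provided object is simply used in place of `MpiComm()`, so anything passed as `comm_backend_for_handle_transfer` needs to offer this same collective surface.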

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 11 additions & 3 deletions
@@ -15,7 +15,7 @@
 
 from ..jit import gen_trtllm_mnnvl_comm_module
 from ..utils import register_custom_op
-from .mnnvl import McastGPUBuffer
+from .mnnvl import McastGPUBuffer, CommBackend
 
 
 def mpi_barrier():
@@ -122,7 +122,10 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype, buffer_size_in_bytes: Optional[int] = None
+    mapping: Mapping,
+    dtype: torch.dtype,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
+    buffer_size_in_bytes: Optional[int] = None,
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -138,6 +141,7 @@ def get_allreduce_mnnvl_workspace(
     Args:
         mapping: Tensor parallel mapping configuration containing rank info
         dtype: Data type of the tensors being reduced
+        comm: Optional communication backend for multi-node synchronization
         buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens
 
     Returns:
@@ -167,14 +171,18 @@ def get_allreduce_mnnvl_workspace(
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
+        comm_backend_for_handle_transfer=comm_backend_for_handle_transfer,
     )
 
     # Initialize the unicast buffer with -0.0
     mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)
 
     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    mpi_barrier()
+    if comm_backend_for_handle_transfer is None:
+        mpi_barrier()
+    else:
+        comm_backend_for_handle_transfer.barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
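
For callers that bring their own communicator, the reworked signature can be used roughly as follows. This is a hedged usage sketch, not code from the commit: `TorchDistBackend` refers to the hypothetical class sketched after the `mnnvl.py` diff above, and `mapping` is assumed to be an already constructed `flashinfer.comm.mapping.Mapping` for the tensor-parallel group.

```python
import torch

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar

# Assumptions: `mapping` was built elsewhere for the tensor-parallel group and
# `TorchDistBackend` is the illustrative backend from the earlier sketch.
comm = TorchDistBackend()

# Fabric-handle exchange and the final CPU barrier both go through `comm`;
# with comm_backend_for_handle_transfer=None the old MPI path is used instead.
mcast_buffer, buffer_flags, max_num_elements = trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
    mapping,
    torch.bfloat16,
    comm_backend_for_handle_transfer=comm,
)
```

Note that the new `comm_backend_for_handle_transfer` parameter is inserted ahead of `buffer_size_in_bytes`, so callers that previously passed a buffer size positionally would need to switch to the keyword form.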

tests/comm/test_trtllm_mnnvl_allreduce.py

Lines changed: 8 additions & 3 deletions
@@ -1,12 +1,13 @@
 # Check torch version:
-from typing import Tuple
+from typing import Tuple, Optional
 
 import pytest
 import torch
 from mpi4py import MPI # Added MPI import
 
 import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
 from flashinfer.comm.mapping import Mapping
+from flashinfer.comm.mnnvl import CommBackend, MpiComm
 
 # Use flashinfer.norm.rmsnorm as reference implementation.
 from flashinfer.norm import rmsnorm
@@ -28,6 +29,7 @@ def row_linear_residual_norm_fusion_forward(
     unicast_ptr: int,
     max_num_elements_mnnvl: int,
     buffer_flags_mnnvl: torch.Tensor,
+    comm_backend_for_handle_transfer: Optional[CommBackend] = None,
 ):
     x = x.cuda()
     residual = residual.cuda()
@@ -36,8 +38,11 @@ def row_linear_residual_norm_fusion_forward(
 
     tensor_parallel_size = mapping.tp_size
     tensor_parallel_rank = mapping.tp_rank
-
-    MPI.COMM_WORLD.barrier()
+    if comm_backend_for_handle_transfer is None:
+        comm = MpiComm()
+    else:
+        comm = comm_backend_for_handle_transfer
+    comm.barrier()
 
     def func(
         input,
