Skip to content

Commit 45a5b82

Browse files
committed
Address review comments.
1 parent 775918d commit 45a5b82

File tree

2 files changed

+69
-16
lines changed

2 files changed

+69
-16
lines changed

csrc/trtllm_mnnvl_allreduce.cu

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ using tvm::ffi::Optional;
2626
} \
2727
}()
2828

29-
// FIXME: is bool flag for oneshot a good idea? Trying to avoid defining a new type/enum at this
30-
// level
3129
void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_ptr,
3230
int64_t buffer_ptrs_dev, int64_t buffer_ptr_local,
3331
TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank,

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 69 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,18 @@ class MNNVLAllreduceFusionStrategy(Enum):
3333
AUTO = 99
3434

3535
@staticmethod
36+
<<<<<<< HEAD
3637
def is_one_shot(tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype) -> bool:
38+
=======
39+
def select_strategy(
40+
tp_size: int, num_tokens: int, hidden_dim: int, dtype: torch.dtype
41+
) -> "MNNVLAllreduceFusionStrategy":
42+
>>>>>>> c6ed1472 (Address review comments.)
3743
elem_size = torch.tensor([], dtype=dtype).element_size()
38-
return num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD
44+
if num_tokens * hidden_dim * tp_size * elem_size <= MNNVL_ONE_SHOT_THRESHOLD:
45+
return MNNVLAllreduceFusionStrategy.ONESHOT
46+
else:
47+
return MNNVLAllreduceFusionStrategy.TWOSHOT
3948

4049

4150
# Empirical result calculated from num_tokens * hidden_dim * tp_size * elem_size
@@ -52,15 +61,15 @@ def __init__(
5261
comm_backend: Optional[CommBackend] = None,
5362
):
5463
"""
55-
Initialize the MNNVL Allreduce Fusion Workspace. COMM_WORLD will be used for creating the workspace and synchronization. The process might hang if the intended communication group in mapping is not COMM_WORLD.
64+
Initialize the MNNVL Allreduce Fusion Workspace. comm_backend will be used for creating the workspace and synchronization. If not provided, MPIBackend will be used which will use COMM_WORLD for synchronization.
5665
5766
Args:
5867
mapping: Mapping configuration containing rank info
5968
buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes.
6069
"""
6170
if buffer_size_in_bytes is None:
62-
# Default to 16MB workspace size if not provided
63-
buffer_size_in_bytes = 16 * (1024**2)
71+
# Default to 512MB workspace size if not provided
72+
buffer_size_in_bytes = 512 * (1024**2)
6473
else:
6574
# Round up to the nearest multiple of 8MB
6675
buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2))
@@ -108,7 +117,28 @@ def __init__(
108117
self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank)
109118
self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr()
110119

120+
@functools.cache
121+
def is_buffer_size_sufficient(
122+
self,
123+
tp_size: int,
124+
num_tokens: int,
125+
hidden_dim: int,
126+
dtype: torch.dtype,
127+
strategy: MNNVLAllreduceFusionStrategy = MNNVLAllreduceFusionStrategy.AUTO,
128+
) -> bool:
129+
"""
130+
Calculate the required buffer size for a given problem size.
131+
"""
132+
required_buffer_size = self.get_required_buffer_size_bytes(
133+
tp_size, num_tokens, hidden_dim, dtype, strategy
134+
)
135+
if required_buffer_size > self.buffer_size_bytes:
136+
return False
137+
else:
138+
return True
139+
111140
@staticmethod
141+
@functools.cache
112142
def get_required_buffer_size_bytes(
113143
tp_size: int,
114144
num_tokens: int,
@@ -120,10 +150,19 @@ def get_required_buffer_size_bytes(
120150
Calculate the required buffer size for a given problem size.
121151
"""
122152
elem_size = torch.tensor([], dtype=dtype).element_size()
153+
<<<<<<< HEAD
123154
is_one_shot = MNNVLAllreduceFusionStrategy.is_one_shot(tp_size, num_tokens, hidden_dim, dtype)
124155
if strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
125156
strategy == MNNVLAllreduceFusionStrategy.AUTO and is_one_shot
126157
):
158+
=======
159+
if strategy == MNNVLAllreduceFusionStrategy.AUTO:
160+
strategy = MNNVLAllreduceFusionStrategy.select_strategy(
161+
tp_size, num_tokens, hidden_dim, dtype
162+
)
163+
164+
if strategy == MNNVLAllreduceFusionStrategy.ONESHOT:
165+
>>>>>>> c6ed1472 (Address review comments.)
127166
# For one-shot, each rank needs to store num_tokens * tp_size tokens
128167
buffer_size = num_tokens * hidden_dim * tp_size * elem_size
129168
else:
@@ -256,10 +295,25 @@ def trtllm_mnnvl_allreduce(
256295

257296
module = get_trtllm_mnnvl_comm_module()
258297

298+
<<<<<<< HEAD
259299
use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
260300
strategy == MNNVLAllreduceFusionStrategy.AUTO
261301
and MNNVLAllreduceFusionStrategy.is_one_shot(workspace.tp_size, input.shape[0], input.shape[1], input.dtype)
262302
)
303+
=======
304+
if strategy == MNNVLAllreduceFusionStrategy.AUTO:
305+
strategy = MNNVLAllreduceFusionStrategy.select_strategy(
306+
workspace.tp_size, input.shape[0], input.shape[1], input.dtype
307+
)
308+
309+
if not workspace.is_buffer_size_sufficient(
310+
workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
311+
):
312+
raise ValueError(
313+
f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes."
314+
)
315+
316+
>>>>>>> c6ed1472 (Address review comments.)
263317
module.trtllm_mnnvl_allreduce_fusion(
264318
input,
265319
workspace.mc_ptr,
@@ -270,7 +324,7 @@ def trtllm_mnnvl_allreduce(
270324
workspace.rank,
271325
False, # No RMSNorm Fusion
272326
launch_with_pdl,
273-
use_oneshot,
327+
strategy == MNNVLAllreduceFusionStrategy.ONESHOT,
274328
output,
275329
None,
276330
None,
@@ -340,15 +394,16 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm(
340394

341395
module = get_trtllm_mnnvl_comm_module()
342396

343-
use_oneshot = strategy == MNNVLAllreduceFusionStrategy.ONESHOT or (
344-
strategy == MNNVLAllreduceFusionStrategy.AUTO
345-
and MNNVLAllreduceFusionStrategy.is_one_shot(
346-
workspace.tp_size,
347-
input.shape[0],
348-
input.shape[1],
349-
input.dtype,
397+
if strategy == MNNVLAllreduceFusionStrategy.AUTO:
398+
strategy = MNNVLAllreduceFusionStrategy.select_strategy(
399+
workspace.tp_size, input.shape[0], input.shape[1], input.dtype
400+
)
401+
if not workspace.is_buffer_size_sufficient(
402+
workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy
403+
):
404+
raise ValueError(
405+
f"The buffer size in the given workspace is insufficient for the given problem size. Buffer: {workspace.buffer_size_bytes} bytes, Required: {workspace.get_required_buffer_size_bytes(workspace.tp_size, input.shape[0], input.shape[1], input.dtype, strategy)} bytes."
350406
)
351-
)
352407

353408
module.trtllm_mnnvl_allreduce_fusion(
354409
input,
@@ -360,7 +415,7 @@ def trtllm_mnnvl_fused_allreduce_add_rmsnorm(
360415
workspace.rank,
361416
True, # RMSNorm Fusion
362417
launch_with_pdl,
363-
use_oneshot,
418+
strategy == MNNVLAllreduceFusionStrategy.ONESHOT,
364419
output,
365420
residual_out,
366421
residual_in,

0 commit comments

Comments (0)