
Commit 815aaf3

Rounding up workspace size according to allocation (page size).
1 parent 45a5b82 commit 815aaf3
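
In short: the workspace request is rounded up to the driver's allocation granularity, and the sizes used afterwards are derived from what was actually allocated rather than from the request. Below is a minimal sketch of the round-up arithmetic only, assuming a hypothetical 2 MiB granularity; in the actual code the value comes from cuMemGetAllocationGranularity.

def round_up_to_granularity(requested_bytes: int, granularity: int) -> int:
    # Round a requested size up to a whole number of allocation units.
    return ((requested_bytes + granularity - 1) // granularity) * granularity

# Illustration with an assumed 2 MiB granularity: an exact multiple is kept,
# anything else is bumped to the next multiple.
assert round_up_to_granularity(500 * 1024**2, 2 * 1024**2) == 500 * 1024**2
assert round_up_to_granularity(500 * 1024**2 + 1, 2 * 1024**2) == 502 * 1024**2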

File tree

2 files changed: +37 -12 lines changed


flashinfer/comm/mnnvl.py

Lines changed: 14 additions & 5 deletions
@@ -803,6 +803,14 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
+    def get_allocation_size(self) -> int:
+        """Get the total allocation size (including signal pad)"""
+        return self.allocation_size
+
+    def get_usable_buffer_size(self) -> int:
+        """Get the usable buffer size (excluding signal pad)"""
+        return self.allocation_size - self.SIGNAL_PAD_SIZE
+
     def _init_ipc_socket(self):
         if self.group_rank == 0:
             # Gnerate the opId
@@ -838,7 +846,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         alloc_granularity = checkCudaErrors(
             cuda.cuMemGetAllocationGranularity(
                 allocation_prop,
-                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
+                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED,
             )
         )
 
@@ -1015,8 +1023,8 @@ def lamport_initialize(self, rank: int, dtype: torch.dtype):
         else:
             raise ValueError(f"Unsupported dtype: {dtype}")
 
-        # Calculate number of elements that fit in allocation_size
-        num_elements = self.allocation_size // dsize
+        # Calculate number of elements that fit in allocation_size; We don't want to include the signal pad.
+        num_elements = (self.allocation_size - self.SIGNAL_PAD_SIZE) // dsize
 
         checkCudaErrors(memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements))
 
@@ -1042,7 +1050,7 @@ def __init__(
         Constructor for McastGpuBuffer.
 
         Args:
-            buf_size: The total size of the buffer in bytes
+            buf_size: The requested size of the buffer in bytes. The actual usable size may differ due to alignment requirements.
             group_size: The number of ranks in the communication group
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation
@@ -1061,7 +1069,8 @@ def __init__(
             mn_nvlink,
             comm_backend_for_handle_transfer,
         )
-        self.buf_size = buf_size
+        # Update buf_size to reflect the actual usable buffer size after allocation
+        self.buf_size = self.mcast_device_memory.get_usable_buffer_size()
         self.local_device = device
 
     def lamport_initialize(self, rank: int, dtype: torch.dtype):
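
A minimal sketch of the size bookkeeping behind the two new accessors, under the assumption that the signal pad shares the allocation with the buffer and that the total is rounded up to the granularity. SIGNAL_PAD_SIZE and GRANULARITY below are placeholder values, not the ones used by McastDeviceMemory.

SIGNAL_PAD_SIZE = 16 * 1024      # placeholder signal-pad size
GRANULARITY = 2 * 1024**2        # placeholder recommended granularity

def plan_allocation(requested_buf_size: int) -> tuple[int, int]:
    """Return (allocation_size, usable_buffer_size) for a requested buffer."""
    # Assumed layout: the signal pad lives inside the same allocation as the
    # buffer, and the total is rounded up to the allocation granularity.
    total = requested_buf_size + SIGNAL_PAD_SIZE
    allocation_size = ((total + GRANULARITY - 1) // GRANULARITY) * GRANULARITY
    # Mirrors get_usable_buffer_size(): everything except the signal pad.
    usable_buffer_size = allocation_size - SIGNAL_PAD_SIZE
    return allocation_size, usable_buffer_size

alloc, usable = plan_allocation(500 * 1024**2)
assert usable == alloc - SIGNAL_PAD_SIZE

Per the diff above, McastGpuBuffer.buf_size is now set from get_usable_buffer_size(), so callers see the post-rounding, pad-excluded capacity rather than the original request.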

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 23 additions & 7 deletions
@@ -65,11 +65,11 @@ def __init__(
 
         Args:
             mapping: Mapping configuration containing rank info
-            buffer_size_in_bytes: The size in bytes for each lamport buffer. The actual allocation size will be NUM_LAMPORT_BUFFERS * buffer_size_in_bytes.
+            buffer_size_in_bytes: The requested size in bytes for each lamport buffer. The actual allocation size may be larger due to alignment requirements. The actual usable size will be NUM_LAMPORT_BUFFERS * actual_buffer_size_per_lamport_buffer.
         """
         if buffer_size_in_bytes is None:
-            # Default to 512MB workspace size if not provided
-            buffer_size_in_bytes = 512 * (1024**2)
+            # Default to 16MB workspace size if not provided
+            buffer_size_in_bytes = 16 * (1024**2)
         else:
             # Round up to the nearest multiple of 8MB
             buffer_size_in_bytes = math.ceil(buffer_size_in_bytes / (8 * (1024**2))) * (8 * (1024**2))
@@ -80,22 +80,38 @@ def __init__(
                 f"The buffer size in bytes {buffer_size_in_bytes} is greater than the maximum supported size (UINT32_MAX)."
             )
 
-        self.buffer_size_bytes = buffer_size_in_bytes
-        self.workspace_size_bytes = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS
+        # Calculate total requested workspace size
+        requested_workspace_size = buffer_size_in_bytes * self.NUM_LAMPORT_BUFFERS
+
         self.rank = mapping.tp_rank
         self.tp_size = mapping.tp_size
         logging.debug(
-            f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with size {buffer_size_in_bytes} bytes."
+            f"[MNNVL Allreduce] TP size: {mapping.tp_size}, rank: {mapping.tp_rank}, Allocating workspace with requested size {buffer_size_in_bytes} bytes per buffer."
         )
+
+        # Allocate the workspace
         self.mcast_buffer_handle = McastGPUBuffer(
-            self.workspace_size_bytes,
+            requested_workspace_size,
             mapping.tp_size,
             mapping.tp_rank,
            torch.device("cuda", mapping.local_rank),
             mapping.is_multi_node(),
             comm_backend,
         )
 
+        # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer)
+        allocated_size = self.mcast_buffer_handle.buf_size
+        # We want the buffer size to be aligned to 16B which is the granularity for buffer management.
+        self.buffer_size_bytes = (
+            math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16
+        )
+        # This workspace size is used for checking the buffer. We need to set it to the actual size in use. The buffer free logic does not rely on this size.
+        self.workspace_size_bytes = self.buffer_size_bytes * self.NUM_LAMPORT_BUFFERS
+
+        logging.debug(
+            f"[MNNVL Allreduce] Actual allocated size: {allocated_size} bytes, Actual buffer size per lamport buffer: {self.buffer_size_bytes} bytes, total workspace: {self.workspace_size_bytes} bytes."
+        )
+
         # We use FP32 for sentinel value regardless of the real dtype
         self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32)
         # Wait until the initialization is done
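
The new sizing logic, restated: the allocation can be larger than the request, so the per-buffer size is recomputed from the actual allocation (McastGPUBuffer.buf_size) and rounded down to a 16-byte boundary before the workspace size is rebuilt from it. A standalone sketch of that arithmetic; NUM_LAMPORT_BUFFERS = 3 is assumed here for illustration only.

import math

NUM_LAMPORT_BUFFERS = 3  # assumed value for illustration

def split_workspace(allocated_size: int) -> tuple[int, int]:
    """Derive (buffer_size_bytes, workspace_size_bytes) from the actual allocation."""
    # Split the allocation evenly across the lamport buffers, then round each
    # buffer down to 16 bytes, the granularity used for buffer management.
    buffer_size_bytes = math.floor(allocated_size / NUM_LAMPORT_BUFFERS) // 16 * 16
    # The tracked workspace size is the part actually in use; it can be
    # slightly smaller than the allocation because of the round-down.
    workspace_size_bytes = buffer_size_bytes * NUM_LAMPORT_BUFFERS
    return buffer_size_bytes, workspace_size_bytes

# Example: a 48 MiB + 4 KiB allocation splits into three 16-byte-aligned
# buffers whose combined size never exceeds the allocation.
buf, ws = split_workspace(48 * 1024**2 + 4 * 1024)
assert ws <= 48 * 1024**2 + 4 * 1024 and buf % 16 == 0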
