
Commit 775918d

Address review comments.
1 parent: 01564e9

4 files changed: +36, -10 lines


csrc/trtllm_mnnvl_allreduce.cu

Lines changed: 9 additions & 2 deletions
@@ -53,11 +53,18 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt
       << "nranks must be between 2 and 64, got " << nranks;
   TVM_FFI_ICHECK(rank >= 0 && rank < nranks)
       << "rank must be between 0 and nranks-1, got " << rank;
-  TVM_FFI_ICHECK((residual_out.has_value() && gamma.has_value() && epsilon.has_value()) ||
+  TVM_FFI_ICHECK((residual_in.has_value() && residual_out.has_value() && gamma.has_value() &&
+                  epsilon.has_value()) ||
                  !rmsnorm_fusion)
-      << "residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is true";
+      << "residual_in, residual_out, gamma, and epsilon must be provided if rmsnorm_fusion is "
+         "true";

   if (rmsnorm_fusion) {
+    TVM_FFI_ICHECK(residual_in.value().size(0) == num_tokens &&
+                   residual_in.value().size(1) == token_dim)
+        << "residual_in shape mismatch: expected (" << input.size(0) << ", " << input.size(1)
+        << ") but got (" << residual_in.value().size(0) << ", " << residual_in.value().size(1)
+        << ")";
     TVM_FFI_ICHECK(residual_out.value().size(0) == num_tokens &&
                    residual_out.value().size(1) == token_dim)
         << "residual_out shape mismatch: expected (" << input.size(0) << ", " << input.size(1)

flashinfer/comm/mnnvl.py

Lines changed: 16 additions & 2 deletions
@@ -716,6 +716,9 @@ def __del__(self):
         if not hasattr(self, "is_multi_node"):
             return

+        if hasattr(self, "_ipc_socket"):
+            self._ipc_socket.close()
+
         # Skip cleanup during Python finalization to avoid segfaults
         # Especially cause the CUDA context could be destroyed at this point.
         if sys.is_finalizing():
@@ -864,7 +867,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         # Allocate local GPU memory
         self.uc_handles[self.group_rank] = checkCudaErrors(cuda.cuMemCreate(self.allocation_size, allocation_prop, 0))

-        # Export local handle to fabric handle
+        # Export local handle to fabric handle or FD
         local_shareable_uc_handle = checkCudaErrors(
             cuda.cuMemExportToShareableHandle(
                 self.uc_handles[self.group_rank],
@@ -898,6 +901,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
                         self._shareable_handle_type,
                     )
                 )
+                if (
+                    self._shareable_handle_type
+                    == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+                ):
+                    # Close FD after import
+                    os.close(all_shareable_uc_handles[p])

         # Initialize multicasting
         if self.group_rank == 0:
@@ -943,7 +952,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
                     self._shareable_handle_type,
                 )
             )
-
+            if (
+                self._shareable_handle_type
+                == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
+            ):
+                # Close FD after import
+                os.close(shareable_mc_handle)
             # Add device to multicast
             checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx))

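
When the shareable handle type is a POSIX file descriptor, the FD delivered over the IPC socket stays open after the import call, so each exchange would leak one descriptor per peer without the explicit os.close calls above. A rough sketch of the same lifecycle using plain Python FD passing (Unix only, Python 3.9+); os.dup stands in for cuMemImportFromShareableHandle and is not what mnnvl.py actually calls:

import os
import socket

# Two ends of a Unix-domain socket; in mnnvl.py this role is played by the
# IPC socket used to ship shareable handles between ranks.
sender, receiver = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)

# "Export" side: share an FD with the peer, then drop the local copy.
fd = os.open("/dev/null", os.O_WRONLY)
socket.send_fds(sender, [b"uc_handle"], [fd])
os.close(fd)

# "Import" side: the kernel delivers a fresh descriptor. Importing it (os.dup
# here, as a stand-in for the CUDA import) leaves the received FD open, so it
# must be closed explicitly or every exchange leaks one descriptor.
msg, fds, _flags, _addr = socket.recv_fds(receiver, 1024, 1)
imported = os.dup(fds[0])  # the "imported handle" owns its own descriptor
os.close(fds[0])           # close the received FD after import

os.close(imported)
sender.close()
receiver.close()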

include/flashinfer/utils.cuh

Lines changed: 9 additions & 4 deletions
@@ -21,6 +21,7 @@
 #include <cuda_fp8.h>
 #include <cuda_runtime.h>

+#include <atomic>
 #include <cstdint>
 #include <iostream>
 #include <type_traits>
@@ -335,16 +336,20 @@ inline std::pair<int, int> GetCudaComputeCapability() {
   return std::make_pair(major, minor);
 }

+// This function is thread-safe and caches the sm_count.
+// It only checks the current CUDA device, thus assuming each process handles a single GPU.
 inline int GetCudaMultiProcessorCount() {
-  static int sm_count = 0;
-  if (sm_count == 0) {
+  static std::atomic<int> sm_count{0};
+  int cached = sm_count.load(std::memory_order_relaxed);
+  if (cached == 0) {
     int device_id;
     cudaGetDevice(&device_id);
     cudaDeviceProp device_prop;
     cudaGetDeviceProperties(&device_prop, device_id);
-    sm_count = device_prop.multiProcessorCount;
+    cached = device_prop.multiProcessorCount;
+    sm_count.store(cached, std::memory_order_relaxed);
   }
-  return sm_count;
+  return cached;
 }

 template <typename T>
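
The utils.cuh change replaces a plain static int, which concurrent first calls could race on, with a relaxed std::atomic<int>; the device query may still run a few times, but every thread sees a well-defined value. A rough Python analogue of the cached per-process lookup, assuming PyTorch with CUDA available (the attribute names are PyTorch's, not flashinfer's):

import functools

import torch


@functools.lru_cache(maxsize=None)
def cuda_multi_processor_count() -> int:
    # Queried once per process and cached, mirroring GetCudaMultiProcessorCount();
    # like the C++ helper, this assumes the process only touches the current GPU.
    device_id = torch.cuda.current_device()
    return torch.cuda.get_device_properties(device_id).multi_processor_count


print(cuda_multi_processor_count())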

tests/comm/test_trtllm_mnnvl_allreduce.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Check torch version:
 import traceback
-from typing import Tuple
+from typing import Tuple, Optional

 import pytest
 import torch
@@ -286,7 +286,7 @@ def run_mnnvl_ar_full(
     fusion: bool,
     dtype: torch.dtype,
     hidden_size: int,
-    legacy_explicit_workspace_bytes: int = None,
+    legacy_explicit_workspace_bytes: Optional[int] = None,
     legacy_api: bool = False,
 ):
     """Core test logic for MNNVL AllReduce operations.
