Skip to content

Commit 92cbd48

Browse files
committed
Address review comments.
1 parent a2670e8 commit 92cbd48

File tree

3 files changed

+2
-7
lines changed

3 files changed

+2
-7
lines changed

csrc/trtllm_mnnvl_allreduce.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ void trtllm_mnnvl_allreduce_fusion(TensorView input, int64_t multicast_buffer_pt
 status = twoshotAllreduceFusionDispatch<c_type>(params);
 }
 TVM_FFI_ICHECK(status == cudaSuccess)
-<< "twoshot_allreduce_dispatch_world_size failed with error code "
-<< cudaGetErrorString(status);
+<< "trtllm_mnnvl_allreduce_fusion failed with error code " << cudaGetErrorString(status);
 });
 }
 

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,6 @@ def trtllm_mnnvl_allreduce_fusion(
 gamma: Gamma tensor (if rmsnorm)
 epsilon: Epsilon value (if rmsnorm)
 """
-print(
-f"[Rank {rank}] Inside Kernel: multicast_buffer_ptr: {multicast_buffer_ptr:x}, buffer_ptrs_dev: {buffer_ptrs_dev:x}, buffer_ptr_local: {buffer_ptr_local:x}, buffer_flags_mnnvl: {buffer_flags_mnnvl}"
-)
 module.trtllm_mnnvl_allreduce_fusion(
 input,
 multicast_buffer_ptr,

tests/comm/test_trtllm_mnnvl_allreduce.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
 # Check torch version:
+import traceback
 from typing import Tuple
 
 import pytest
@@ -406,8 +407,6 @@ def run_mnnvl_ar_full(
 rank_failed = True
 failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
 print(failure_message)
-import traceback
-
 print(traceback.format_exc())
 
 # Gather failure status from all ranks for logging

0 commit comments

Comments
 (0)