From a74fa01bb25a8c5a6c666039017ebbe7724eb48b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 2 Dec 2025 03:40:28 +0000 Subject: [PATCH 1/3] proposal Signed-off-by: Lucas Wilkinson --- _posts/2025-11-27-improved-cuda-debugging.md | 153 ++++++++++++++++--- 1 file changed, 128 insertions(+), 25 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index bc2b9e4..88e7889 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -278,45 +278,148 @@ $ grep -C20 7ff533bb91d0 output.txt The main difference is obtaining the CUDA function index (the `-fun` argument) from `cuobjdump` by searching the function's ELF section, which is `26a` in this case. -Note that this is a simplified example to demonstrate the technique. Real-world kernels can be much more complex. For example, here is a complex inline case: +Note that this is a simplified example to demonstrate the technique. Real-world kernels can be even more complex. 
We can see this with a slightly more complicated, vLLM-specific example that uses a CUTLASS GEMM kernel integrated into vLLM (NOTE: this example is intended for an SM90 (Hopper) device): +```python +# save as illegal_memory_access.py + +from dataclasses import dataclass +import torch + +@dataclass +class TensorWrapper: + data_ptr: int + size_in_bytes: int + dtype_str: str = '|u1' + + @property + def __cuda_array_interface__(self): + return { "shape": (self.size_in_bytes,), "typestr": self.dtype_str, "data": (self.data_ptr, False), "version": 3 } + + +def from_buffer(data_ptr: int, size_in_bytes: int, device: str, + dtype: torch.dtype) -> torch.Tensor: + return torch.as_tensor(TensorWrapper(data_ptr, size_in_bytes), device=device).view(dtype) + + +import vllm._custom_ops as ops + +M, K, N = 128, 256, 256 +b = torch.randn(K, N, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t().contiguous().t() +a_scales = torch.ones(1, device="cuda", dtype=torch.float32) +b_scales = torch.ones(1, device="cuda", dtype=torch.float32) + +# Create tensor 'a' with an INVALID data pointer (will cause illegal memory access) +invalid_ptr = 0x123456 +a_size_bytes = M * K # FP8 is 1 byte per element +a = from_buffer(invalid_ptr, a_size_bytes, device="cuda:0", dtype=torch.float8_e4m3fn) +a = a.view(M, K) + +# This will trigger an illegal memory access when CUTLASS tries to read from 'a' +result = ops.cutlass_scaled_mm( + a=a, + b=b, + scale_a=a_scales, + scale_b=b_scales, + out_dtype=torch.bfloat16, +) + +print(result) +``` + +Following the same steps as before we first rebuild vLLM with lineinfo; If vLLM was installed via an editable install (i.e.
`-e .`) this can be done using: + +```bash +NVCC_PREPEND_FLAGS="-lineinfo" python setup.py build_ext --inplace +``` + +Then run the code with CUDA core dump enabled: + +```bash +CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \ +CUDA_COREDUMP_SHOW_PROGRESS=1 \ +CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \ +CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \ +python illegal_memory_access.py +``` ```text - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 185 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 185 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_traits.hpp", line 133 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_traits.hpp", line 133 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 103 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 103 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 124 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 124 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 211 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 211 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 412 - //## File 
"/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 412 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/epilogue_fwd.hpp", line 265 - //## File "/data/youkaichao/data/vllm_flash_attn/hopper/epilogue_fwd.hpp", line 265 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/flash_fwd_kernel_sm90.h", line 454 - //## File "/data/youkaichao/data/vllm_flash_attn/hopper/flash_fwd_kernel_sm90.h", line 454 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/utils.h", line 41 - //## File "/data/youkaichao/data/vllm_flash_attn/hopper/utils.h", line 41 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cutlass/device_kernel.h", line 122 - //## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cutlass/device_kernel.h", line 122 - /*7eebf5e9eb80*/ STSM.16.M88.4 [R13], R4 ; - /*7eebf5e9eb90*/ MOV R34, R26 ; +(cuda-gdb) target cudacore /tmp/cuda_coredump_nm-automation-h100-standalone-0-preserve.361991.1764626086 +Opening GPU coredump: /tmp/cuda_coredump_nm-automation-h100-standalone-0-preserve.361991.1764626086 +[Current focus set to CUDA kernel 0, grid 6, cluster (0,1,0), block (0,1,0), thread (0,0,0), device 0, sm 124, warp 2, lane 0] + +CUDA Exception: Warp Illegal Instruction +The exception was triggered at PC 0x7f5687bbb580 void cutlass::device_kernel, cute::C<128>, cute::C<128> >, cute::tuple, cute::C<1>, cute::C<1> >, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::epilogue::TmaWarpSpecialized, false>::GemmKernel>(vllm::cutlass_3x_gemm_sm90_fp8, cute::C<128>, cute::C<128> >, cute::tuple, cute::C<1>, cute::C<1> >, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::epilogue::TmaWarpSpecialized, false>::GemmKernel::Params) (copy_sm90_tma.hpp:185 in _ZN4cute16SM90_TMA_LOAD_3D4copyEPKvPmmPvRKiS6_S6_ inlined from copy_sm90_tma.hpp:348) +#0 cutlass::device_kernel, cute::C<128>, cute::C<128> >, cute::tuple, cute::C<1>, cute::C<1> >, 
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::epilogue::TmaWarpSpecialized, false>::GemmKernel><<<(2,2,1),(384,1,1)>>> () + at /usr/local/cuda-12.9/include/sm_20_intrinsics.hpp:151 in _ZN52_INTERNAL_64778a7b_21_scaled_mm_sm90_fp8_cu_e01c669e24__cvta_generic_to_sharedEPKv inlined from util.hpp:108 ``` -In this case, the problematic code is: +From the kernel name, we can see that the issue is caused by vLLM's CUTLASS FP8 GEMM kernel (`cutlass_3x_gemm_sm90_fp8`). This is a heavily templated kernel with deep inlining—exactly the scenario where standard debugging falls short. For example if we use `info line *$errorpc`: + +```text +(cuda-gdb) info line *$errorpc +Line 185 of "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/copy_sm90_tma.hpp" +``` + +This leads us to: +```c++ + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +``` -

- - -
-A line of poisoned code in the attention kernel. -

+Unfortunately this is not very useful since CUTLASS GEMM implementations issue many TMA operations for various operands (e.g., matrices A, B, C, scales for A, etc.). Instead let's follow the steps laid out above. First use `info symbol $errorpc` to get more information about the error location: -The faulty source code calls some CUTLASS functions, and the function containing it also gets inlined by an upper-level caller. In this case, `cuda-gdb` cannot correctly associate the line. In fact, it does not show any line information around the error location. Even when it shows the correct line, it only displays the last inline frame, which is `File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158`—an internal inline expansion of the CUTLASS function that is still unhelpful for debugging the underlying issue. +``` +(cuda-gdb) info symbol $errorpc +void cutlass::device_kernel, cute::C<128>, cute::C<128> >, cute::tuple, cute::C<1>, cute::C<1> >, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::epilogue::TmaWarpSpecialized, false>::GemmKernel>(vllm::cutlass_3x_gemm_sm90_fp8, cute::C<128>, cute::C<128> >, cute::tuple, cute::C<1>, cute::C<1> >, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::epilogue::TmaWarpSpecialized, false>::GemmKernel::Params) + 16256 in section .text._ZN7cutlass13device_kernelIN4vllm24cutlass_3x_gemm_sm90_fp8INS_12float_e4m3_tENS_10bfloat16_tENS1_3c3x14ScaledEpilogueEN4cute5tupleIJNS7_1CILi64EEENS9_ILi128EEESB_EEENS8_IJNS9_ILi2EEENS9_ILi1EEESE_EEENS_4gemm44KernelTmaWarpSpecializedPingpongFP8FastAccumENS_8epilogue18TmaWarpSpecializedELb0EE10GemmKernelEEEvNT_6ParamsE of /tmp/cuda-dbg/439034/session1/elf.55caf0a395a0.55caf38fedc0.o.XNSLjS +``` + +Then disassemble with line info using nvdisasm: + +```bash +$ nvdisasm -ndf -c -gi 
/tmp/cuda-dbg/439034/session1/elf.55caf0a395a0.55caf38fedc0.o.XNSLjS > output.txt +$ grep -C20 "7f5687bbb580" output.txt +``` + +This reveals a deep inline call chain: + +``` + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/copy_sm90_tma.hpp", line 185 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/copy_sm90_tma.hpp", line 348 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/copy_sm90_tma.hpp", line 348 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/util.hpp", line 158 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/util.hpp", line 158 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/util.hpp", line 315 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/arch/util.hpp", line 315 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/atom/copy_traits_sm90_tma.hpp", line 82 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/atom/copy_traits_sm90_tma.hpp", line 82 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/atom/copy_atom.hpp", line 103 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/atom/copy_atom.hpp", line 103 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/atom/copy_atom.hpp", line 124 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/atom/copy_atom.hpp", line 124 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/algorithm/copy.hpp", line 226 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/algorithm/copy.hpp", line 226 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/algorithm/copy.hpp", line 545 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cute/algorithm/copy.hpp", line 545 inlined at
"/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp", line 384 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp", line 384 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp", line 643 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp", line 643 inlined at "/home/LucasWilkinson/code/vllm/csrc/cutlass_extensions/common.hpp", line 39 + //## File "/home/LucasWilkinson/code/vllm/csrc/cutlass_extensions/common.hpp", line 39 inlined at "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cutlass/device_kernel.h", line 123 + //## File "/home/LucasWilkinson/code/vllm/.deps/cutlass-src/include/cutlass/device_kernel.h", line 123 + /*7f5687bbb580*/ UTMALDG.3D [UR8], [UR14], desc[UR16] ; +``` + +Now we can trace the issue back through the full call chain — from ptx instruction we saw before all the way up to where it is instantiated in vLLM. Following the call chain we can get to a contextually useful line, in this case that is in CUTLASS's collective mainloop (`sm90_mma_tma_gmma_ss_warpspecialized.hpp`): +```c++ +copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); +``` +source: https://github.com/NVIDIA/cutlass/blob/f3fde58372d33e9a5650ba7b80fc48b3b49d40c8/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp#L384 -With the approach outlined above, we can uncover the full inline chain of the source code and carefully examine each frame to identify which line is responsible for the error. +This is more helpful as it informs us the issue is with loading the A matrix specifically, which makes sense since we corrupted the pointer of the A matrix. 
**Warning:** To maximize the benefit of CUDA core dumps, line information is crucial. It is recommended to compile with the `export NVCC_PREPEND_FLAGS='-lineinfo'` environment variable, as this transparently applies to all compiled kernels without needing to modify compilation scripts. However, this transparency means that if you use a compilation caching mechanism such as `ccache`, it may ignore the flag and reuse previously compiled results without actual compilation. When compiling from source, ensure that the compilation caching mechanism is disabled. If you use Just-In-Time compilation, please consult the documentation of your Just-In-Time compilation tool to see how to add line information. ## Conclusion This blog post introduced two advanced debugging techniques for CUDA kernels. The first technique uses user-triggered core dumps to identify hanging kernels, while the second traces complex kernels back to their source code by leveraging line information embedded in the compiled binary. These techniques are powerful tools for debugging complex issues in CUDA kernels, especially illegal memory access problems. +Using both the `user induced GPU core dump generation` and `nvdisasm` techniques we were able to recently debug a hard-to-reproduce and tricky hang in the CUTLASS MLA attention backend: https://github.com/vllm-project/vllm/pull/26026 (this bug actually stemmed from the upstream CUTLASS code example and has since been fixed in [v4.3.0](https://github.com/NVIDIA/cutlass/commit/b1d6e2c9b334dfa811e4183dfbd02419249e4b52)). -The vLLM project aims to provide easy, fast, and affordable LLM serving for everyone, and accessible debugging is an important aspect of this mission. We will continue to share more debugging tips and techniques in the future to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). 
+The vLLM project aims to provide easy, fast, stable, and affordable LLM serving for everyone, and accessible debugging is an important aspect of this mission. We will continue to share more debugging tips and techniques in the future to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). # Acknowledgement From 034828cc2a50ef64e0cda4681c2e216372ffc004 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 2 Dec 2025 03:48:35 +0000 Subject: [PATCH 2/3] cleanup Signed-off-by: Lucas Wilkinson --- _posts/2025-11-27-improved-cuda-debugging.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 88e7889..0125e51 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -327,7 +327,7 @@ result = ops.cutlass_scaled_mm( print(result) ``` -Following the same steps as before we first rebuild vLLM with lineinfo; If vLLM was installed via an editable install (i.e. `-e .`) this can be done using: +Following the same steps as before we first rebuild vLLM with lineinfo; if vLLM was installed via an editable install (i.e. `-e .`) this can be done using: ```bash NVCC_PREPEND_FLAGS="-lineinfo" python setup.py build_ext --inplace @@ -404,7 +404,7 @@ This reveals a deep inline call chain: /*7f5687bbb580*/ UTMALDG.3D [UR8], [UR14], desc[UR16] ; ``` -Now we can trace the issue back through the full call chain — from ptx instruction we saw before all the way up to where it is instantiated in vLLM. 
Following the call chain we can get to a contextually useful line, in this case that is in CUTLASS's collective mainloop (`sm90_mma_tma_gmma_ss_warpspecialized.hpp`): +Now we can trace the issue back through the full call chain, from the PTX instruction we saw earlier all the way to the `device_kernel` entry point. Following the call chain, we can reach a contextually useful line, which in this case is in CUTLASS's collective mainloop (`sm90_mma_tma_gmma_ss_warpspecialized.hpp`): ```c++ copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); ``` @@ -416,8 +416,7 @@ This is more helpful as it informs us the issue is with loading the A matrix spe ## Conclusion -This blog post introduced two advanced debugging techniques for CUDA kernels. The first technique uses user-triggered core dumps to identify hanging kernels, while the second traces complex kernels back to their source code by leveraging line information embedded in the compiled binary. These techniques are powerful tools for debugging complex issues in CUDA kernels, especially illegal memory access problems. -Using both the `user induced GPU core dump generation` and `nvdisasm` techniques we were able to recently debug a hard-to-reproduce and tricky hang in the CUTLASS MLA attention backend: https://github.com/vllm-project/vllm/pull/26026 (this bug actually stemmed from the upstream CUTLASS code example and has since been fixed in [v4.3.0](https://github.com/NVIDIA/cutlass/commit/b1d6e2c9b334dfa811e4183dfbd02419249e4b52)). +This blog post introduced two advanced debugging techniques for CUDA kernels. The first technique uses user-triggered core dumps to identify hanging kernels, while the second traces complex kernels back to their source code by leveraging line information embedded in the compiled binary. These techniques are powerful tools for debugging complex issues in CUDA kernels, especially illegal memory access problems.
Using both in tandem, we were recently able to debug a tricky, hard-to-reproduce hang in the CUTLASS MLA attention backend: https://github.com/vllm-project/vllm/pull/26026 (this bug actually stemmed from the upstream CUTLASS code example and has since been fixed in [v4.3.0](https://github.com/NVIDIA/cutlass/commit/b1d6e2c9b334dfa811e4183dfbd02419249e4b52)). The vLLM project aims to provide easy, fast, stable, and affordable LLM serving for everyone, and accessible debugging is an important aspect of this mission. We will continue to share more debugging tips and techniques in the future to build a strong LLM inference ecosystem together. To share your story or usage with vLLM, please submit a PR at [the blogpost repository](https://github.com/vllm-project/vllm-project.github.io). From ff618dee848c8a5980a7d0e968eaaaef4aaf54d1 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 2 Dec 2025 03:52:25 +0000 Subject: [PATCH 3/3] minor fix Signed-off-by: Lucas Wilkinson --- _posts/2025-11-27-improved-cuda-debugging.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_posts/2025-11-27-improved-cuda-debugging.md b/_posts/2025-11-27-improved-cuda-debugging.md index 0125e51..df0cce7 100644 --- a/_posts/2025-11-27-improved-cuda-debugging.md +++ b/_posts/2025-11-27-improved-cuda-debugging.md @@ -327,7 +327,7 @@ result = ops.cutlass_scaled_mm( print(result) ``` -Following the same steps as before we first rebuild vLLM with lineinfo; if vLLM was installed via an editable install (i.e. `-e .`) this can be done using: +Following the same steps as before, we first rebuild vLLM with line information (`-lineinfo`). If vLLM was installed as an editable install (i.e. `-e . --no-build-isolation`), this can be done using: ```bash NVCC_PREPEND_FLAGS="-lineinfo" python setup.py build_ext --inplace