# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for CUDA kernels in cache_kernels.cu."""

import pytest
import torch

# cache_ops has moved between vllm versions; try both known locations
# before skipping the whole module.
try:
    from vllm import cache_ops
except ImportError:
    try:
        from vllm.ops import cache_ops
    except ImportError:
        pytest.skip("Could not import vllm cache_ops. Skipping test.",
                    allow_module_level=True)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
def test_gather_cache_oob_issue_27909():
    """Regression test for an OOB read in gather_and_maybe_dequant_cache
    (Issue #27909).

    Constructs the boundary case identified in the issue: a ``seq_starts``
    value whose block offset (``seq_starts // block_size``) indexes one past
    the end of the block-table row. Run under compute-sanitizer to observe
    the out-of-bounds read; completing the call and the synchronize without
    a CUDA error is the pass condition.
    """
    batch_size = 1
    block_size = 64
    entry_size = 128

    # The block-table row has only 2 entries (valid column indices 0 and 1).
    block_table = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")

    # Offset = 128 // block_size = 128 // 64 = 2, so the kernel reads
    # block_table[0, 2] -- one past the end of the 2-entry row.
    seq_starts = torch.tensor([128], dtype=torch.int32, device="cuda")

    seq_len = 1
    # Cumulative sequence lengths, length batch_size + 1: [0, 1].
    cu_seq_lens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

    # src_cache layout: [num_blocks, block_size, entry_size].
    num_blocks = 5
    src_cache = torch.randn(
        (num_blocks, block_size, entry_size),
        dtype=torch.float16,
        device="cuda",
    )

    # Destination buffer for the gathered tokens: [total_tokens, entry_size].
    dst = torch.empty(
        (seq_len, entry_size),
        dtype=torch.float16,
        device="cuda",
    )

    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

    # Invoke the C++/CUDA kernel under test.
    cache_ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        batch_size,
        "auto",  # kv_cache_dtype
        scale,
        seq_starts,
    )

    # Force any asynchronous CUDA error (e.g. from the OOB access) to
    # surface here rather than in a later, unrelated call. Reaching this
    # point without an exception is the pass condition; the former
    # `assert True` was a no-op and has been removed.
    torch.cuda.synchronize()

if __name__ == "__main__":
    # Allow running this test file directly: `python <this_file>`.
    pytest.main([__file__])