Skip to content

Commit fb8851f

Browse files
authored
[Bugfix][cache_kernels]: Fix OOB in cache_kernels.cu (#28760)
Signed-off-by: vensen <vensenmu@gmail.com> Signed-off-by: Vensenmu <vensenmu@gmail.com>
1 parent a903d59 commit fb8851f

File tree

2 files changed: +77 −7 lines changed

csrc/cache_kernels.cu

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,9 @@ __global__ void gather_and_maybe_dequant_cache(
965965
}
966966
};
967967

968-
for (int pid = split_start; pid < full_blocks_end; ++pid) {
968+
const auto loop_end =
969+
std::min((int64_t)full_blocks_end, block_table_stride - offset);
970+
for (int pid = split_start; pid < loop_end; ++pid) {
969971
auto block_id = batch_block_table[pid];
970972
auto block_start_ptr = src_cache + block_id * cache_block_stride;
971973
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
@@ -976,12 +978,15 @@ __global__ void gather_and_maybe_dequant_cache(
976978
}
977979

978980
if (partial_block_size) {
979-
auto block_id = batch_block_table[full_blocks_end];
980-
auto block_start_ptr = src_cache + block_id * cache_block_stride;
981-
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
982-
for (int eid = 0; eid < partial_block_size; ++eid) {
983-
copy_entry(block_start_ptr + eid * cache_entry_stride,
984-
block_dst_ptr + eid * dst_entry_stride);
981+
if (offset + full_blocks_end < block_table_stride) {
982+
auto block_id = batch_block_table[full_blocks_end];
983+
auto block_start_ptr = src_cache + block_id * cache_block_stride;
984+
auto block_dst_ptr =
985+
dst + full_blocks_end * block_size * dst_entry_stride;
986+
for (int eid = 0; eid < partial_block_size; ++eid) {
987+
copy_entry(block_start_ptr + eid * cache_entry_stride,
988+
block_dst_ptr + eid * dst_entry_stride);
989+
}
985990
}
986991
}
987992
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Unit tests for CUDA kernels in cache_kernels.cu."""
4+
5+
import pytest
6+
import torch
7+
8+
try:
9+
from vllm import _custom_ops as ops
10+
except ImportError:
11+
pytest.skip(
12+
"Could not import vllm._custom_ops. (pip install -e .)", allow_module_level=True
13+
)
14+
15+
16+
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
def test_gather_cache_oob():
    """
    Regression test for an out-of-bounds read in
    gather_and_maybe_dequant_cache (Issue #27909).

    seq_starts is chosen so the kernel's derived block-table offset
    (seq_starts[0] // block_size == 128 // 64 == 2) points one past the
    last entry of a 2-entry block table; an unguarded kernel would read
    block_table[0, 2] out of bounds. The test passes if the launch and
    the following synchronize complete without a CUDA error.
    """
    device = "cuda"
    batch_size = 1
    block_size = 64
    entry_size = 128
    num_blocks = 5
    seq_len = 65

    # Block table with exactly two entries for the single sequence.
    block_table = torch.tensor([[1, 2]], dtype=torch.int32, device=device)

    # 128 // block_size == 2, i.e. an offset equal to the table width —
    # the boundary case that previously read past the end of the table.
    seq_starts = torch.tensor([128], dtype=torch.int32, device=device)

    cu_seq_lens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)

    # Source cache laid out as [num_blocks, block_size, entry_size].
    src_cache = torch.randn(
        (num_blocks, block_size, entry_size), dtype=torch.float16, device=device
    )
    dst = torch.empty((seq_len, entry_size), dtype=torch.float16, device=device)
    scale = torch.tensor([1.0], dtype=torch.float32, device=device)

    # Launch the custom op; with the OOB fix this must not fault.
    ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        batch_size,
        "auto",  # kv_cache_dtype
        scale,
        seq_starts,
    )

    # Surface any asynchronous CUDA execution error from the kernel.
    torch.cuda.synchronize()
    assert True
62+
63+
if __name__ == "__main__":
    # Allow running this test module directly with `python <file>`.
    pytest.main([__file__])

0 commit comments

Comments (0)