Skip to content

Commit 2e048d1

Browse files
committed
WIP: Fix OOB in cache_kernels.cu (Issue #27909)
1 parent 8977ffb commit 2e048d1

File tree

2 files changed

+92
-9
lines changed

2 files changed

+92
-9
lines changed

csrc/cache_kernels.cu

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -966,22 +966,26 @@ __global__ void gather_and_maybe_dequant_cache(
966966
};
967967

968968
for (int pid = split_start; pid < full_blocks_end; ++pid) {
969-
auto block_id = batch_block_table[pid];
970-
auto block_start_ptr = src_cache + block_id * cache_block_stride;
971-
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
972-
for (int eid = 0; eid < block_size; ++eid) {
973-
copy_entry(block_start_ptr + eid * cache_entry_stride,
969+
if (offset + pid < block_table_stride){
970+
auto block_id = batch_block_table[pid];
971+
auto block_start_ptr = src_cache + block_id * cache_block_stride;
972+
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
973+
for (int eid = 0; eid < block_size; ++eid) {
974+
copy_entry(block_start_ptr + eid * cache_entry_stride,
974975
block_dst_ptr + eid * dst_entry_stride);
976+
}
975977
}
976978
}
977979

978980
if (partial_block_size) {
979-
auto block_id = batch_block_table[full_blocks_end];
980-
auto block_start_ptr = src_cache + block_id * cache_block_stride;
981-
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
982-
for (int eid = 0; eid < partial_block_size; ++eid) {
981+
if (offset + full_blocks_end < block_table_stride) {
982+
auto block_id = batch_block_table[full_blocks_end];
983+
auto block_start_ptr = src_cache + block_id * cache_block_stride;
984+
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
985+
for (int eid = 0; eid < partial_block_size; ++eid) {
983986
copy_entry(block_start_ptr + eid * cache_entry_stride,
984987
block_dst_ptr + eid * dst_entry_stride);
988+
}
985989
}
986990
}
987991
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""Unit tests for CUDA kernels in cache_kernels.cu."""
4+
5+
import torch
6+
import pytest
7+
8+
try:
9+
from vllm import cache_ops
10+
except ImportError:
11+
try:
12+
from vllm.ops import cache_ops
13+
except ImportError:
14+
pytest.skip("Could not import vllm cache_ops. Skipping test.",
15+
allow_module_level=True)
16+
17+
18+
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
def test_gather_cache_oob_issue_27909():
    """Regression test for the out-of-bounds block-table read in
    gather_and_maybe_dequant_cache (Issue #27909).

    The block table has a single row with two entries, while ``seq_starts``
    points one full block past the end of that row, so the starting
    block-table index (``seq_starts // block_size``) equals the row width
    and is out of range.

    An out-of-bounds device read does not reliably raise an error, so a
    bare "did not crash" check is not enough. ``dst`` is pre-filled with a
    sentinel value and the test asserts that the kernel left it untouched.
    NOTE(review): this expects the kernel to skip the copy for an
    out-of-range block index rather than raise on the host side; confirm
    against the final form of the fix.
    """
    batch_size = 1
    block_size = 64
    entry_size = 128
    num_blocks = 5
    seq_len = 1

    block_table = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda")

    # Start exactly one block past the last valid block-table column:
    # offset = seq_starts // block_size = 128 // 64 = 2, but the row only
    # has indices 0 and 1.
    blocks_per_seq = block_table.shape[1]
    seq_starts = torch.tensor(
        [blocks_per_seq * block_size], dtype=torch.int32, device="cuda"
    )

    # Cumulative sequence lengths, BATCH + 1 entries: [0, 1].
    cu_seq_lens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

    # src_cache: [num_blocks, block_size, entry_size]
    src_cache = torch.randn(
        (num_blocks, block_size, entry_size),
        dtype=torch.float16,
        device="cuda",
    )

    # Sentinel-filled destination so that any write by the kernel is
    # detectable (random cache rows will not match the sentinel).
    sentinel = -1.0
    dst = torch.full(
        (seq_len, entry_size), sentinel, dtype=torch.float16, device="cuda"
    )
    expected = dst.clone()

    scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

    cache_ops.gather_and_maybe_dequant_cache(
        src_cache,
        dst,
        block_table,
        cu_seq_lens,
        batch_size,
        "auto",  # kv_cache_dtype
        scale,
        seq_starts,
    )

    # Surface any asynchronous kernel fault before inspecting the output.
    torch.cuda.synchronize()

    assert torch.equal(dst, expected), (
        "gather_and_maybe_dequant_cache wrote to dst using an "
        "out-of-bounds block_table entry"
    )
77+
78+
if __name__ == "__main__":
    # Allow running this file directly as a script: hand it to pytest.
    pytest.main(args=[__file__])

0 commit comments

Comments
 (0)