
Commit 8ee846c

[Bugfix] Re-enable prefill of max model length (vllm-project#24446)
Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
1 parent 812b7f5 commit 8ee846c

2 files changed: 113 additions, 8 deletions
New test file

Lines changed: 91 additions & 0 deletions
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
end-to-end tests for context length corner cases of vLLM v1 model runner
versus HuggingFace's transformers.

This test verifies the following behavior: allow a prefill that fills the
model's maximum context length and then request a single new token.

Test strategy
- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
- Run vLLM generation requesting a single new token (max_tokens=1).
- Run HF generation on the same prompt requesting a single token too.
- Assert both return the same number of generated tokens and the same ids.

"""

import pytest
import torch
from transformers import AutoModelForCausalLM

from tests.models.utils import check_outputs_equal
from tests.utils import create_new_process_for_each_test
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt


@create_new_process_for_each_test()
@pytest.mark.parametrize("model", ["JackFram/llama-160m"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [1])
def test_prefill_max_context_length(
    model: str,
    max_model_len: int,
    max_tokens: int,
) -> None:
    """Compare vLLM and HuggingFace when the prompt already fills the
    model's maximum context length and we request a single new token.

    The test ensures vLLM does not raise the "Sampled token IDs exceed the
    max model length" assertion and that both vLLM and HF produce the same
    single token when given the same inputs.
    """

    # Construct a prompt of size max_model_len
    prompt_ids = [[43] * max_model_len]

    # Generate max_tokens new tokens deterministically.
    sampling_params = [
        SamplingParams(max_tokens=max_tokens, temperature=0.0, ignore_eos=True)
    ]

    # --- vLLM generation ---
    llm = LLM(
        model=model,
        tokenizer=model,
        max_num_seqs=1,
        tensor_parallel_size=1,
    )

    vllm_token_prompts = [TokensPrompt(prompt_token_ids=prompt_ids[0])]
    vllm_results = llm.generate(vllm_token_prompts, sampling_params)

    vllm_output_ids = vllm_results[0].outputs[0].token_ids

    # --- HuggingFace generation ---
    with torch.no_grad():
        hf_model = AutoModelForCausalLM.from_pretrained(model)

        # HF expects a tensor of input ids shaped (batch, seq_len).
        hf_input_tokens = torch.tensor(prompt_ids[0]).unsqueeze(0)

        # Generate max_tokens new tokens deterministically.
        hf_generated = hf_model.generate(
            hf_input_tokens,
            do_sample=False,
            min_new_tokens=max_tokens,
            max_new_tokens=max_tokens,
        )

    # HF returns the prompt + generated tokens. Slice off the prompt.
    hf_output_ids = hf_generated.cpu().tolist()[0][len(prompt_ids[0]):]

    # check that vLLM outputs (token ids) match HF outputs
    # Note: for simplicity don't pass detokenized string
    check_outputs_equal(
        outputs_0_lst=[(hf_output_ids, "")],
        outputs_1_lst=[(vllm_output_ids, "")],
        name_0="hf",
        name_1="vllm",
    )
```
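
For a quick manual check outside pytest, the same corner case can be exercised through the public `LLM` API. This is a minimal sketch and not part of the commit: it assumes the same `JackFram/llama-160m` model and passes `max_model_len=2048` explicitly (the test above relies on the model's default context length).

```python
# Minimal reproduction sketch (not part of the commit): prefill exactly
# max_model_len tokens, then sample a single new token. Before this fix the
# bookkeeping assertion "Sampled token IDs exceed the max model length" fired.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

MAX_MODEL_LEN = 2048  # assumed; matches the test's parametrization

llm = LLM(model="JackFram/llama-160m", max_model_len=MAX_MODEL_LEN,
          max_num_seqs=1)

# The prompt occupies every slot of the context window.
prompt = TokensPrompt(prompt_token_ids=[43] * MAX_MODEL_LEN)
params = SamplingParams(max_tokens=1, temperature=0.0, ignore_eos=True)

out = llm.generate([prompt], [params])
print(out[0].outputs[0].token_ids)  # exactly one generated token id
```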

vllm/v1/worker/gpu_model_runner.py

Lines changed: 22 additions & 8 deletions
```diff
@@ -2247,14 +2247,28 @@ def _bookkeeping_sync(
 
             start_idx = self.input_batch.num_tokens_no_spec[req_idx]
             end_idx = start_idx + len(sampled_ids)
-            assert end_idx <= self.max_model_len, (
-                "Sampled token IDs exceed the max model length. "
-                f"Total number of tokens: {end_idx} > max_model_len: "
-                f"{self.max_model_len}")
-
-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
-            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
+            assert end_idx <= self.max_model_len + 1, (
+                "Sampled token IDs exceed the max model length + 1. "
+                f"Total number of tokens: {end_idx} > max_model_len + 1: "
+                f"{self.max_model_len + 1}")
+
+            n_tokens_cache = len(sampled_ids)
+
+            # Sampled token IDs exceed the max model length by 1. This is
+            # legitimate as we can still sample 1 last token when the context
+            # length equals the max model length. Note that we do not need to
+            # cache this token ID as the sequence finishes after this step.
+            # Additionally, the buffers token_ids_cpu and is_token_ids are of
+            # size max model length only.
+            if end_idx == self.max_model_len + 1:
+                n_tokens_cache -= 1
+
+            self.input_batch.token_ids_cpu[req_idx, start_idx:(
+                start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache]
+            self.input_batch.is_token_ids[req_idx,
+                                          start_idx:(start_idx +
+                                                     n_tokens_cache)] = True
+
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
```
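
The change above amounts to an off-by-one allowance in the bookkeeping: a prompt that already fills all `max_model_len` slots can still sample one final token, so `end_idx` may reach `max_model_len + 1`, but that token never needs to be written back into the fixed-size `token_ids_cpu`/`is_token_ids` buffers because the sequence finishes on this step. The toy sketch below illustrates just that clamping with hypothetical names (`cache_sampled_tokens`, `token_buf`); it is not vLLM's actual code path.

```python
import numpy as np


def cache_sampled_tokens(token_buf: np.ndarray, start_idx: int,
                         sampled_ids: list[int], max_model_len: int) -> int:
    """Toy version of the clamped bookkeeping; returns the new end index."""
    end_idx = start_idx + len(sampled_ids)
    # One position past the buffer is allowed: that token ends the sequence
    # and is never fed back in, so it is simply not cached.
    assert end_idx <= max_model_len + 1
    n_cache = len(sampled_ids) - (1 if end_idx == max_model_len + 1 else 0)
    token_buf[start_idx:start_idx + n_cache] = sampled_ids[:n_cache]
    return end_idx


buf = np.zeros(8, dtype=np.int64)               # pretend max_model_len == 8
print(cache_sampled_tokens(buf, 5, [7, 9], 8))  # -> 7, both tokens cached
print(cache_sampled_tokens(buf, 8, [7], 8))     # -> 9, full context: last token not cached
```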
