Commit 5e4a0a1

njhill and Dhruvilbhatt authored and committed
[BugFix] Fix async scheduling + request preemption (vllm-project#26385)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
1 parent e46ad3e commit 5e4a0a1

File tree

2 files changed: +104 -3 lines changed

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import pytest

from vllm import SamplingParams

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal

MODEL = "Qwen/Qwen3-0.6B"


def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
    """Test consistency of combos of async scheduling, preemption,
    uni/multiproc executor, and various sampling parameters."""

    first_prompt = (
        "The following numbers of the sequence "
        + ", ".join(str(i) for i in range(10))
        + " are:"
    )
    example_prompts = [first_prompt, "In one word, the capital of France is "] + [
        f"Tell me about the number {i}: " for i in range(32)
    ]

    sampling_param_tests: list[dict[str, Any]] = [
        dict(),
        # dict(min_tokens=20),
        # TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
        # dict(repetition_penalty=0.1),
        # dict(bad_words=[]),
    ]

    default_params = dict(
        temperature=0.0,  # greedy
        max_tokens=20,
    )

    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")

        outputs = []
        for test_preemption in [False, True]:
            for executor in ["uni", "mp"]:
                for async_scheduling in [False, True]:
                    cache_arg: dict[str, Any] = (
                        dict(num_gpu_blocks_override=32)
                        if test_preemption
                        else dict(gpu_memory_utilization=0.7)
                    )
                    test_config = (
                        f"executor={executor}, preemption={test_preemption},"
                        f" async_sched={async_scheduling}"
                    )
                    print("-" * 80)
                    print(f"---- TESTING: {test_config}")
                    print("-" * 80)
                    with VllmRunner(
                        MODEL,
                        max_model_len=512,
                        enforce_eager=True,
                        async_scheduling=async_scheduling,
                        distributed_executor_backend=executor,
                        dtype="float32",  # avoid precision errors
                        **cache_arg,
                    ) as vllm_model:
                        results = []
                        for override_params in sampling_param_tests:
                            print(f"----------- RUNNING PARAMS: {override_params}")
                            results.append(
                                vllm_model.generate(
                                    example_prompts,
                                    sampling_params=SamplingParams(
                                        **default_params, **override_params
                                    ),
                                )
                            )
                        outputs.append((test_config, results))

    baseline_config, baseline_tests = outputs[0]

    for test_config, test_outputs in outputs[1:]:
        for base_outs, test_outs, params in zip(
            baseline_tests, test_outputs, sampling_param_tests
        ):
            check_outputs_equal(
                outputs_0_lst=base_outs,
                outputs_1_lst=test_outs,
                name_0=f"baseline=[{baseline_config}], params={params}",
                name_1=f"config=[{test_config}], params={params}",
            )

        print(f"PASSED: config=[{test_config}], params={params}")

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 3 deletions
@@ -754,6 +754,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
+                if self.use_async_scheduling and num_output_tokens > 0:
+                    # We must recover the output token ids for resumed requests in the
+                    # async scheduling case, so that correct input_ids are obtained.
+                    resumed_token_ids = req_data.resumed_req_token_ids[i]
+                    assert resumed_token_ids is not None
+                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -991,7 +997,7 @@ def _prepare_input_ids(
             self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
-            # So input_ids_cpu will have all the input ids.
+            # So input_ids.cpu will have all the input ids.
             return
         if indices_match and max_flattened_index == (num_commmon_tokens - 1):
             # Common-case optimization: the batch is unchanged
@@ -1005,8 +1011,7 @@ def _prepare_input_ids(
             if self.enable_prompt_embeds:
                 self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
-        # Upload the index tensors asynchronously
-        # so the scatter can be non-blocking.
+        # Upload the index tensors asynchronously so the scatter can be non-blocking.
         input_ids_index_tensor = torch.tensor(
             flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
         ).to(self.device, non_blocking=True)
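The first hunk above is the substance of the fix: when a preempted request is resumed under async scheduling, the runner's cached output_token_ids must be rebuilt from the tail of the resumed token ids so that the next step's input_ids include the tokens generated before preemption. Below is a small self-contained sketch of that recovery step, using plain Python lists as stand-ins for the runner's state and assuming resumed_req_token_ids carries the request's full token history (prompt plus prior outputs):

# Illustrative sketch only; mirrors resumed_token_ids[-num_output_tokens:]
# from the hunk above with hypothetical stand-in values.
prompt_token_ids = [101, 102, 103]           # tokens known when the request was enqueued
generated_before_preemption = [7, 8, 9, 10]  # outputs produced before preemption

# Assumed shape of what the scheduler hands back for a resumed request:
# the full token history, prompt plus previously generated outputs.
resumed_token_ids = prompt_token_ids + generated_before_preemption
num_output_tokens = len(generated_before_preemption)

# The recovery step: rebuild the cached output ids from the tail of the
# resumed ids so subsequent input_ids see the pre-preemption outputs.
output_token_ids = resumed_token_ids[-num_output_tokens:]
assert output_token_ids == generated_before_preemption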

0 commit comments
