
Commit 8e7a891

[BugFix] Fix spec decoding max_tokens scheduling perf issue (#29542)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 953d9c8 commit 8e7a891

3 files changed: 28 additions, 38 deletions

tests/v1/test_outputs.py

Lines changed: 11 additions & 13 deletions
@@ -43,15 +43,15 @@ def test_slice_without_cu_num_generated_tokens(self):
             cu_num_generated_tokens=None,
         )

-        sliced = logprobsLists.slice(1, 3)
+        sliced = logprobsLists.slice_request(1, num_positions=2)
         assert sliced.logprob_token_ids == [[2], [3]]
         assert sliced.logprobs == [[0.2], [0.3]]
         assert sliced.sampled_token_ranks == [2, 3]
         assert sliced.cu_num_generated_tokens is None

     def test_slice_from_start(self):
         """Test slicing from the start position"""
-        sliced = self.logprobsLists.slice(0, 2)
+        sliced = self.logprobsLists.slice_request(0, num_positions=5)
         assert len(sliced.logprob_token_ids) == 5
         assert sliced.logprob_token_ids == [
             [1, 2],
@@ -60,11 +60,11 @@ def test_slice_from_start(self):
             [7, 8],
             [9, 10],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 2, 5]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_from_middle(self):
         """Test slicing from the middle position"""
-        sliced = self.logprobsLists.slice(1, 3)
+        sliced = self.logprobsLists.slice_request(1, num_positions=7)
         assert len(sliced.logprob_token_ids) == 7
         assert sliced.logprob_token_ids == [
             [5, 6],
@@ -75,27 +75,25 @@ def test_slice_from_middle(self):
             [15, 16],
             [17, 18],
         ]
-        assert sliced.cu_num_generated_tokens == [0, 3, 7]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_single_request(self):
         """Test slicing a single request"""
-        sliced = self.logprobsLists.slice(1, 2)
+        sliced = self.logprobsLists.slice_request(1, num_positions=3)
         assert len(sliced.logprob_token_ids) == 3
         assert sliced.logprob_token_ids == [[5, 6], [7, 8], [9, 10]]
-        assert sliced.cu_num_generated_tokens == [0, 3]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_last_request(self):
         """Test slicing the last request"""
-        sliced = self.logprobsLists.slice(2, 3)
+        sliced = self.logprobsLists.slice_request(2, num_positions=4)
         assert len(sliced.logprob_token_ids) == 4
         assert sliced.logprob_token_ids == [[11, 12], [13, 14], [15, 16], [17, 18]]
-        assert sliced.cu_num_generated_tokens == [0, 4]
+        assert sliced.cu_num_generated_tokens is None

     def test_slice_all_requests(self):
         """Test slicing all requests (full slice)"""
-        sliced = self.logprobsLists.slice(0, 3)
+        sliced = self.logprobsLists.slice_request(0, num_positions=9)
         assert len(sliced.logprob_token_ids) == 9  # All tokens
         assert sliced.logprob_token_ids == self.logprobsLists.logprob_token_ids
-        assert (
-            sliced.cu_num_generated_tokens == self.logprobsLists.cu_num_generated_tokens
-        )
+        assert sliced.cu_num_generated_tokens is None
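
For reference, the expected values above are consistent with a setUp fixture along these lines. This is a reconstruction from the assertions only; the real fixture's logprobs and sampled ranks are not shown in this diff, so placeholder values are used:

    from vllm.v1.outputs import LogprobsLists

    # Reconstructed fixture: three requests that generated 2, 3 and 4 tokens
    # (9 positions total), so the cumulative array is [0, 2, 5, 9].
    logprobsLists = LogprobsLists(
        logprob_token_ids=[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
                           [11, 12], [13, 14], [15, 16], [17, 18]],
        logprobs=[[0.0, 0.0]] * 9,       # placeholder values, not from the diff
        sampled_token_ranks=[0] * 9,     # placeholder values, not from the diff
        cu_num_generated_tokens=[0, 2, 5, 9],
    )

    # e.g. slice_request(1, num_positions=3) -> positions 2..4 (request 1's tokens):
    # logprob_token_ids == [[5, 6], [7, 8], [9, 10]], cu_num_generated_tokens is None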

vllm/v1/core/sched/scheduler.py

Lines changed: 9 additions & 5 deletions
@@ -234,11 +234,15 @@ def schedule(self) -> SchedulerOutput:
                 num_new_tokens = self.scheduler_config.long_prefill_token_threshold
             num_new_tokens = min(num_new_tokens, token_budget)

-            # Make sure the input position does not exceed the max model len or
-            # request's max_tokens.
-            # This is necessary when using spec decoding and/or async scheduling.
+            num_spec_placeholders = max(0, request.num_output_placeholders - 1)
             max_total_tokens = min(
-                request.num_prompt_tokens + request.max_tokens, self.max_model_len
+                # Avoid scheduling tokens that we're sure won't be needed based on
+                # request.max_tokens. For this calculation we assume placeholder
+                # speculated output tokens are rejected.
+                request.num_prompt_tokens + request.max_tokens + num_spec_placeholders,
+                # Make sure the input position does not exceed the max model len.
+                # This is necessary when using spec decoding.
+                self.max_model_len,
             )
             num_new_tokens = min(
                 num_new_tokens, max_total_tokens - 1 - request.num_computed_tokens
@@ -1089,7 +1093,7 @@ def update_from_output(
                 and request.sampling_params.logprobs is not None
                 and logprobs
             ):
-                new_logprobs = logprobs.slice(req_index, req_index + 1)
+                new_logprobs = logprobs.slice_request(req_index, len(new_token_ids))

             if new_token_ids and self.structured_output_manager.should_advance(request):
                 struct_output_request = request.structured_output_request
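
To make the new bound concrete, here is a standalone sketch with made-up request state; the variable values below are hypothetical, only the formulas come from the diff. Per the new comments, all but one of the request's output placeholders are treated as speculative tokens that may still be rejected, which widens the max_tokens bound, while the max_model_len clamp is unchanged:

    # Hypothetical request state, for illustration only.
    num_prompt_tokens = 100        # request.num_prompt_tokens
    max_tokens = 8                 # request.max_tokens (sampling limit)
    num_output_placeholders = 3    # request.num_output_placeholders
    num_computed_tokens = 104      # request.num_computed_tokens
    max_model_len = 4096           # self.max_model_len
    num_new_tokens = 16            # value after the earlier min() with token_budget

    num_spec_placeholders = max(0, num_output_placeholders - 1)          # 2
    max_total_tokens = min(
        num_prompt_tokens + max_tokens + num_spec_placeholders,          # 110
        max_model_len,
    )
    num_new_tokens = min(num_new_tokens, max_total_tokens - 1 - num_computed_tokens)
    # New bound: min(16, 110 - 1 - 104) = 5
    # Old bound for the same state: min(16, (100 + 8) - 1 - 104) = 3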

vllm/v1/outputs.py

Lines changed: 8 additions & 20 deletions
@@ -29,27 +29,15 @@ class LogprobsLists(NamedTuple):
     # different for each request.
     cu_num_generated_tokens: list[int] | None = None

-    def slice(self, start_req_idx: int, end_req_idx: int):
-        if self.cu_num_generated_tokens:
-            start = self.cu_num_generated_tokens[start_req_idx]
-            end = self.cu_num_generated_tokens[end_req_idx]
-            # Recompute cumulative array starting from 0
-            cu_num_offset = self.cu_num_generated_tokens[start_req_idx]
-            sliced_cu_num_generated_tokens = [
-                cu_num - cu_num_offset
-                for cu_num in self.cu_num_generated_tokens[
-                    start_req_idx : end_req_idx + 1
-                ]
-            ]
-        else:
-            start = start_req_idx
-            end = end_req_idx
-            sliced_cu_num_generated_tokens = None
+    def slice_request(self, req_idx: int, num_positions: int):
+        if self.cu_num_generated_tokens is not None:
+            req_idx = self.cu_num_generated_tokens[req_idx]
+        end_idx = req_idx + num_positions
         return LogprobsLists(
-            self.logprob_token_ids[start:end],
-            self.logprobs[start:end],
-            self.sampled_token_ranks[start:end],
-            sliced_cu_num_generated_tokens,
+            self.logprob_token_ids[req_idx:end_idx],
+            self.logprobs[req_idx:end_idx],
+            self.sampled_token_ranks[req_idx:end_idx],
+            None,
         )
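As a usage sketch of the reworked helper (the values below are illustrative, not taken from the commit): callers now pass the request index together with the number of positions that request generated (len(new_token_ids) at the scheduler call site), and the returned slice always carries cu_num_generated_tokens=None since it covers a single request.

    from vllm.v1.outputs import LogprobsLists

    lists = LogprobsLists(
        logprob_token_ids=[[10], [11], [12], [13]],
        logprobs=[[-0.1], [-0.2], [-0.3], [-0.4]],
        sampled_token_ranks=[0, 1, 0, 2],
        # Two requests: request 0 generated 1 token, request 1 generated 3.
        cu_num_generated_tokens=[0, 1, 4],
    )

    # New API: look up the request's start offset in the cumulative array,
    # then take num_positions entries for that single request.
    req1 = lists.slice_request(1, num_positions=3)
    assert req1.logprob_token_ids == [[11], [12], [13]]
    assert req1.cu_num_generated_tokens is None

    # The old slice(1, 2) call returned the same token data but also rebuilt a
    # zero-based cumulative array ([0, 3]) for the slice.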