Commit 818983d

[BugFix] Fix PP performance and PP kv connector output regression
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent c1c005e
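
This commit restores pipelining in the PP (pipeline-parallel) batch-queue path and fixes how KV connector output is propagated. `step_with_batch_queue` previously block-waited on `exec_future.result()` before queuing the next step; it now attaches an error-logging done-callback (`_log_err_callback`) to the execute future instead. The KV connector output pass-through moves from `gpu_worker.execute_model()` into `gpu_model_runner.sample_tokens()`, with the runner stashing `kv_connector_output` between `execute_model()` and `sample_tokens()`.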

File tree

3 files changed: +55 -44 lines

vllm/v1/engine/core.py
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_worker.py

vllm/v1/engine/core.py

Lines changed: 36 additions & 25 deletions
@@ -180,11 +180,13 @@ def __init__(
             logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
             self.batch_queue = deque(maxlen=self.batch_queue_size)

+        self.ec_producer = (
+            vllm_config.ec_transfer_config is not None
+            and vllm_config.ec_transfer_config.is_ec_producer
+        )
+
         self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
-        if (
-            self.vllm_config.cache_config.enable_prefix_caching
-            or kv_connector is not None
-        ):
+        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
             caching_hash_fn = get_hash_fn_by_name(
                 vllm_config.cache_config.prefix_caching_hash_algo
             )
@@ -244,7 +246,7 @@ def _initialize_kv_caches(

        elapsed = time.time() - start
        logger.info_once(
-            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
+            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
            elapsed,
            scope="local",
        )
@@ -310,6 +312,16 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
            )
            raise err

+    def _log_err_callback(self, scheduler_output: SchedulerOutput):
+        """Log error details of a future that's not expected to return a result."""
+
+        def callback(f, sched_output=scheduler_output):
+            with self.log_error_detail(sched_output):
+                result = f.result()
+                assert result is None
+
+        return callback
+
    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
        """Schedule, execute, and make output.
@@ -373,31 +385,30 @@ def step_with_batch_queue(
            exec_future = self.model_executor.execute_model(
                scheduler_output, non_block=True
            )
-            model_executed = scheduler_output.total_num_scheduled_tokens > 0
-
-            if scheduler_output.pending_structured_output_tokens:
-                # We need to defer sampling until we have processed the model output
-                # from the prior step.
-                deferred_scheduler_output = scheduler_output
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-                    assert exec_result is None
+            if not self.ec_producer:
+                model_executed = scheduler_output.total_num_scheduled_tokens > 0
+
+            if not model_executed:
+                # No sampling required (no requests scheduled).
+                future = cast(Future[ModelRunnerOutput], exec_future)
            else:
-                # We aren't waiting for any tokens, get any grammar output immediately.
-                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-
-                if exec_result is None:
-                    # Call sample tokens.
+                exec_future.add_done_callback(self._log_err_callback(scheduler_output))
+
+                if not scheduler_output.pending_structured_output_tokens:
+                    # We aren't waiting for any tokens, get any grammar output
+                    # and sample immediately.
+                    grammar_output = self.scheduler.get_grammar_bitmask(
+                        scheduler_output
+                    )
                    future = self.model_executor.sample_tokens(
                        grammar_output, non_block=True
                    )
                else:
-                    # No sampling required (e.g. all requests finished).
-                    future = cast(Future[ModelRunnerOutput], exec_future)
+                    # We need to defer sampling until we have processed the model output
+                    # from the prior step.
+                    deferred_scheduler_output = scheduler_output
+
+            if not deferred_scheduler_output:
                # Add this step's future to the queue.
                batch_queue.appendleft((future, scheduler_output))
            if (
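
The performance fix above hinges on replacing a blocking `Future.result()` call with `Future.add_done_callback`, so the loop can enqueue the next step while the previous one is still executing. Below is a minimal, self-contained sketch of that pattern using plain `concurrent.futures`; `run_step` and `log_err_callback` are hypothetical stand-ins for the executor call and vLLM's `_log_err_callback`, not the actual vLLM API.

import time
from concurrent.futures import Future, ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)

def run_step(step_id: int) -> None:
    # Stand-in for model execution; like execute_model(non_block=True),
    # it is expected to return no result.
    time.sleep(0.1)

def log_err_callback(step_id: int):
    """Factory that binds step_id so the callback can say which step failed."""
    def callback(f: Future) -> None:
        try:
            result = f.result()  # Re-raises any exception from the worker.
            assert result is None
        except Exception:
            print(f"step {step_id} failed")
            raise
    return callback

batch_queue = []
for step_id in range(3):
    fut = executor.submit(run_step, step_id)
    # Non-blocking: attach error handling instead of calling fut.result()
    # here, so the loop can immediately schedule the next step.
    fut.add_done_callback(log_err_callback(step_id))
    batch_queue.append((fut, step_id))

# Outputs are consumed later, in order, preserving pipelining across steps.
for fut, step_id in batch_queue:
    fut.exception()  # Waits for completion without swallowing errors.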

vllm/v1/worker/gpu_model_runner.py

Lines changed: 18 additions & 5 deletions
@@ -7,7 +7,7 @@
 from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
-from copy import deepcopy
+from copy import copy, deepcopy
 from functools import reduce
 from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -240,7 +240,6 @@ class ExecuteModelState(NamedTuple):
    hidden_states: torch.Tensor
    sample_hidden_states: torch.Tensor
    aux_hidden_states: list[torch.Tensor] | None
-    kv_connector_output: KVConnectorOutput | None
    ec_connector_output: ECConnectorOutput | None

@@ -549,6 +548,7 @@ def __init__(

        # Ephemeral state transferred between execute_model() and sample_tokens().
        self.execute_model_state: ExecuteModelState | None = None
+        self.kv_connector_output: KVConnectorOutput | None = None

    def reset_mm_cache(self) -> None:
        if self.mm_budget:
@@ -2674,6 +2674,7 @@ def execute_model(
            # Return the intermediate tensors.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
+            self.kv_connector_output = kv_connector_output
            return hidden_states

        if self.is_pooling_model:
@@ -2724,18 +2725,31 @@ def execute_model(
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
-            kv_connector_output,
            ec_connector_output,
        )
+        self.kv_connector_output = kv_connector_output
        return None

    @torch.inference_mode
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        kv_connector_output = self.kv_connector_output
+        self.kv_connector_output = None
+
        if self.execute_model_state is None:
            # Nothing to do (PP non-final rank case), output isn't used.
-            return None  # noqa
+            if not kv_connector_output:
+                return None  # noqa
+
+            # In case of PP with kv transfer, we need to pass through the
+            # kv_connector_output
+            if kv_connector_output.is_empty():
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output

        # Unpack ephemeral state.
        (
@@ -2746,7 +2760,6 @@ def sample_tokens(
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
-            kv_connector_output,
            ec_connector_output,
        ) = self.execute_model_state
        # Clear ephemeral state.
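
The runner now carries `kv_connector_output` across the `execute_model()` → `sample_tokens()` boundary as ephemeral instance state rather than as a field of `ExecuteModelState`, so ranks with no sampling work (`execute_model_state is None`) can still pass it through. A minimal sketch of this stash-and-consume idiom under stated assumptions: `ConnectorOutput`, `RunnerOutput`, and `EMPTY_OUTPUT` are hypothetical stand-ins for `KVConnectorOutput`, `ModelRunnerOutput`, and `EMPTY_MODEL_RUNNER_OUTPUT`.

from copy import copy
from dataclasses import dataclass, field

@dataclass
class ConnectorOutput:
    # Hypothetical stand-in for KVConnectorOutput.
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)

    def is_empty(self) -> bool:
        return not (self.finished_sending or self.finished_recving)

@dataclass
class RunnerOutput:
    # Hypothetical stand-in for ModelRunnerOutput.
    connector_output: ConnectorOutput | None = None

# Shared empty template, like EMPTY_MODEL_RUNNER_OUTPUT.
EMPTY_OUTPUT = RunnerOutput()

class Runner:
    def __init__(self) -> None:
        # Ephemeral state handed from execute_model() to sample_tokens().
        self.connector_output: ConnectorOutput | None = None

    def execute_model(self, connector_output: ConnectorOutput | None) -> None:
        # ... run the model, possibly on a non-final PP rank ...
        self.connector_output = connector_output  # stash for sample_tokens()

    def sample_tokens(self) -> RunnerOutput | None:
        connector_output = self.connector_output
        self.connector_output = None  # always clear the ephemeral state

        if not connector_output:
            return None  # nothing to pass through
        if connector_output.is_empty():
            return EMPTY_OUTPUT  # safe to reuse: nothing will be attached
        # Shallow-copy the shared template instead of mutating it, then
        # attach the pass-through connector output.
        output = copy(EMPTY_OUTPUT)
        output.connector_output = connector_output
        return output

The `self.connector_output = None` reset mirrors the real change: the stash is consumed exactly once per step, so stale connector output can never leak into a later step.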

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 14 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A GPU worker class."""

-import copy
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
@@ -45,7 +44,6 @@
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (
-    EMPTY_MODEL_RUNNER_OUTPUT,
     AsyncModelRunnerOutput,
     DraftTokenIds,
     ModelRunnerOutput,
@@ -573,18 +571,7 @@ def execute_model(
                all_gather_tensors=all_gather_tensors,
            )

-        kv_connector_output = output.kv_connector_output
-        if not kv_connector_output:
-            return None
-
-        # In case of PP with kv transfer, we need to pass through the
-        # kv_connector_output
-        if kv_connector_output.is_empty():
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.kv_connector_output = kv_connector_output
-        return output
+        return None

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        return self.model_runner.take_draft_token_ids()
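
With the pass-through now handled inside the model runner's `sample_tokens()`, the worker's `execute_model()` simply returns `None` on ranks whose output isn't consumed, and the `copy` and `EMPTY_MODEL_RUNNER_OUTPUT` imports become dead code and are removed.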
