Commit bf6ecea

[BugFix] Fix PP performance and PP kv connector output regression
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 6d7de5f commit bf6ecea

3 files changed: 54 additions, 49 deletions

vllm/v1/engine/core.py

Lines changed: 36 additions & 30 deletions
@@ -11,7 +11,7 @@
 from contextlib import ExitStack, contextmanager
 from inspect import isclass, signature
 from logging import DEBUG
-from typing import Any, TypeVar, cast
+from typing import Any, TypeVar
 
 import msgspec
 import zmq
@@ -180,11 +180,13 @@ def __init__(
             logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
             self.batch_queue = deque(maxlen=self.batch_queue_size)
 
+        self.ec_producer = (
+            vllm_config.ec_transfer_config is not None
+            and vllm_config.ec_transfer_config.is_ec_producer
+        )
+
         self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
-        if (
-            self.vllm_config.cache_config.enable_prefix_caching
-            or kv_connector is not None
-        ):
+        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
             caching_hash_fn = get_hash_fn_by_name(
                 vllm_config.cache_config.prefix_caching_hash_algo
             )
@@ -244,7 +246,7 @@ def _initialize_kv_caches(
 
         elapsed = time.time() - start
         logger.info_once(
-            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
+            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
             elapsed,
             scope="local",
         )
@@ -310,6 +312,16 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
             )
             raise err
 
+    def _log_err_callback(self, scheduler_output: SchedulerOutput):
+        """Log error details of a future that's not expected to return a result."""
+
+        def callback(f, sched_output=scheduler_output):
+            with self.log_error_detail(sched_output):
+                result = f.result()
+                assert result is None
+
+        return callback
+
     def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         """Schedule, execute, and make output.
@@ -370,34 +382,28 @@ def step_with_batch_queue(
         deferred_scheduler_output = None
         if self.scheduler.has_requests():
             scheduler_output = self.scheduler.schedule()
-            exec_future = self.model_executor.execute_model(
-                scheduler_output, non_block=True
-            )
-            model_executed = scheduler_output.total_num_scheduled_tokens > 0
-
-            if scheduler_output.pending_structured_output_tokens:
-                # We need to defer sampling until we have processed the model output
-                # from the prior step.
-                deferred_scheduler_output = scheduler_output
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-                    assert exec_result is None
-            else:
-                # We aren't waiting for any tokens, get any grammar output immediately.
-                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-
-                if exec_result is None:
-                    # Call sample tokens.
+            future = self.model_executor.execute_model(scheduler_output, non_block=True)
+            if not self.ec_producer:
+                model_executed = scheduler_output.total_num_scheduled_tokens > 0
+
+                if model_executed:
+                    future.add_done_callback(self._log_err_callback(scheduler_output))
+
+                if not scheduler_output.pending_structured_output_tokens:
+                    # We aren't waiting for any tokens, get any grammar output
+                    # and sample immediately.
+                    grammar_output = self.scheduler.get_grammar_bitmask(
+                        scheduler_output
+                    )
                     future = self.model_executor.sample_tokens(
                         grammar_output, non_block=True
                     )
                 else:
-                    # No sampling required (e.g. all requests finished).
-                    future = cast(Future[ModelRunnerOutput], exec_future)
+                    # We need to defer sampling until we have processed the model output
+                    # from the prior step.
+                    deferred_scheduler_output = scheduler_output
+
+            if not deferred_scheduler_output:
                 # Add this step's future to the queue.
                 batch_queue.appendleft((future, scheduler_output))
                 if (
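
The performance fix here is the removal of the blocking exec_future.result() calls: execute_model is now fire-and-forget, with error logging attached through add_done_callback, so sample_tokens can be submitted without waiting for the forward pass to return. A minimal, self-contained sketch of that pattern using Python's concurrent.futures directly; the pool and the two step functions below are illustrative stand-ins, not vLLM's executor API:

from concurrent.futures import Future, ThreadPoolExecutor
import time

# A single worker thread stands in for the serialized model-executor worker.
pool = ThreadPoolExecutor(max_workers=1)

def execute_model_step() -> None:
    time.sleep(0.05)  # stand-in for the forward pass

def sample_tokens_step() -> str:
    time.sleep(0.01)  # stand-in for sampling
    return "tokens"

def log_err_callback(fut: Future) -> None:
    # Runs when the future completes; surfaces errors without blocking the caller.
    exc = fut.exception()
    if exc is not None:
        print(f"execute_model failed: {exc!r}")

# Before: block on the execute future before submitting the sampling step,
# so the submitting loop sits idle for the whole forward pass.
exec_future = pool.submit(execute_model_step)
exec_future.result()
sample_future = pool.submit(sample_tokens_step)
print(sample_future.result())

# After: attach the error-logging callback and submit sampling right away;
# the single worker still runs the two steps in order, but the submitting
# loop is free to do other work (e.g. process the previous step's output).
exec_future = pool.submit(execute_model_step)
exec_future.add_done_callback(log_err_callback)
sample_future = pool.submit(sample_tokens_step)
print(sample_future.result())

pool.shutdown()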

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 5 deletions
@@ -7,7 +7,7 @@
 from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
-from copy import deepcopy
+from copy import copy, deepcopy
 from functools import reduce
 from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -242,7 +242,6 @@ class ExecuteModelState(NamedTuple):
     hidden_states: torch.Tensor
     sample_hidden_states: torch.Tensor
     aux_hidden_states: list[torch.Tensor] | None
-    kv_connector_output: KVConnectorOutput | None
     ec_connector_output: ECConnectorOutput | None
 
 
@@ -551,6 +550,7 @@ def __init__(
 
         # Ephemeral state transferred between execute_model() and sample_tokens().
         self.execute_model_state: ExecuteModelState | None = None
+        self.kv_connector_output: KVConnectorOutput | None = None
 
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
@@ -2732,18 +2732,31 @@ def execute_model(
             hidden_states,
             sample_hidden_states,
             aux_hidden_states,
-            kv_connector_output,
             ec_connector_output,
         )
+        self.kv_connector_output = kv_connector_output
         return None
 
     @torch.inference_mode
     def sample_tokens(
         self, grammar_output: "GrammarOutput | None"
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        kv_connector_output = self.kv_connector_output
+        self.kv_connector_output = None
+
         if self.execute_model_state is None:
             # Nothing to do (PP non-final rank case), output isn't used.
-            return None  # noqa
+            if not kv_connector_output:
+                return None  # noqa
+
+            # In case of PP with kv transfer, we need to pass through the
+            # kv_connector_output
+            if kv_connector_output.is_empty():
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output
 
         # Unpack ephemeral state.
         (
@@ -2754,7 +2767,6 @@ def sample_tokens(
             hidden_states,
             sample_hidden_states,
             aux_hidden_states,
-            kv_connector_output,
             ec_connector_output,
         ) = self.execute_model_state
         # Clear ephemeral state.
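
The regression fix is the second part of this change: kv_connector_output no longer travels inside ExecuteModelState, it is stashed on the runner in execute_model() and consumed in sample_tokens(), so PP non-final ranks (which have no execute_model_state) can still pass KV-transfer status back to the scheduler. A rough sketch of that pass-through, using simplified stand-in classes; KVOutput and RunnerOutput below are illustrative, not vLLM's real types:

from copy import copy
from dataclasses import dataclass, field

@dataclass
class KVOutput:
    # Stand-in for KVConnectorOutput: requests whose KV transfer finished.
    finished_sending: list[str] = field(default_factory=list)

    def is_empty(self) -> bool:
        return not self.finished_sending

@dataclass
class RunnerOutput:
    # Stand-in for ModelRunnerOutput.
    sampled_token_ids: list[list[int]] = field(default_factory=list)
    kv_connector_output: KVOutput | None = None

EMPTY_OUTPUT = RunnerOutput()

def passthrough(kv_connector_output: KVOutput | None) -> RunnerOutput | None:
    # PP non-final rank: nothing was sampled, but KV-transfer status must still
    # reach the scheduler, so it rides on a copy of the empty output.
    if not kv_connector_output:
        return None
    if kv_connector_output.is_empty():
        return EMPTY_OUTPUT
    out = copy(EMPTY_OUTPUT)  # shallow copy so the shared constant is not mutated
    out.kv_connector_output = kv_connector_output
    return out

print(passthrough(KVOutput(finished_sending=["req-1"])))  # carries the KV status
print(passthrough(None))                                  # nothing to report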

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 14 deletions
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A GPU worker class."""
 
-import copy
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
@@ -45,7 +44,6 @@
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (
-    EMPTY_MODEL_RUNNER_OUTPUT,
     AsyncModelRunnerOutput,
     DraftTokenIds,
     ModelRunnerOutput,
@@ -573,18 +571,7 @@ def execute_model(
             all_gather_tensors=all_gather_tensors,
         )
 
-        kv_connector_output = output.kv_connector_output
-        if not kv_connector_output:
-            return None
-
-        # In case of PP with kv transfer, we need to pass through the
-        # kv_connector_output
-        if kv_connector_output.is_empty():
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.kv_connector_output = kv_connector_output
-        return output
+        return None
 
     def take_draft_token_ids(self) -> DraftTokenIds | None:
         return self.model_runner.take_draft_token_ids()
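
With the pass-through now handled in GPUModelRunner.sample_tokens(), the worker-side copy becomes dead weight: under the split execute_model()/sample_tokens() protocol it is the sampling call whose return value the engine consumes, so status attached to the worker's execute_model() result would not reliably reach the scheduler on a non-final PP rank. A toy sketch of that ordering; RunnerShim is a hypothetical shim, not vLLM's Worker API:

class RunnerShim:
    """Illustrative stand-in for the model runner's two-phase step."""

    def __init__(self) -> None:
        self.kv_connector_output: str | None = None

    def execute_model(self) -> None:
        # Forward pass; stash the connector status instead of returning it,
        # because the engine ignores this call's result.
        self.kv_connector_output = "kv-transfer-status"
        return None

    def sample_tokens(self) -> str | None:
        # Emit the stashed status here, where the engine actually collects output.
        out, self.kv_connector_output = self.kv_connector_output, None
        return out

runner = RunnerShim()
assert runner.execute_model() is None   # result of phase 1 is discarded
print(runner.sample_tokens())           # phase 2 returns "kv-transfer-status"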

0 commit comments
