
Commit 96ed367

njhill authored and bigPYJ1151 committed
[BugFix] Fix PP performance and PP kv connector output regression (#28768)
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
1 parent e734cd6 commit 96ed367

File tree

4 files changed: +105 −104 lines


vllm/v1/engine/core.py

Lines changed: 66 additions & 84 deletions
```diff
@@ -63,7 +63,6 @@
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
-from vllm.v1.utils import record_function_or_nullcontext
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -181,11 +180,13 @@ def __init__(
             logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
             self.batch_queue = deque(maxlen=self.batch_queue_size)
 
+        self.ec_producer = (
+            vllm_config.ec_transfer_config is not None
+            and vllm_config.ec_transfer_config.is_ec_producer
+        )
+
         self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
-        if (
-            self.vllm_config.cache_config.enable_prefix_caching
-            or kv_connector is not None
-        ):
+        if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
             caching_hash_fn = get_hash_fn_by_name(
                 vllm_config.cache_config.prefix_caching_hash_algo
             )
@@ -246,7 +247,7 @@ def _initialize_kv_caches(
 
         elapsed = time.time() - start
         logger.info_once(
-            ("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
+            "init engine (profile, create kv cache, warmup model) took %.2f seconds",
             elapsed,
             scope="local",
         )
@@ -312,6 +313,16 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
             )
             raise err
 
+    def _log_err_callback(self, scheduler_output: SchedulerOutput):
+        """Log error details of a future that's not expected to return a result."""
+
+        def callback(f, sched_output=scheduler_output):
+            with self.log_error_detail(sched_output):
+                result = f.result()
+                assert result is None
+
+        return callback
+
     def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         """Schedule, execute, and make output.
 
@@ -323,21 +334,17 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
             return {}, False
-        with record_function_or_nullcontext("core step: schedule"):
-            scheduler_output = self.scheduler.schedule()
-
-        with record_function_or_nullcontext("core step: execute_model"):
-            future = self.model_executor.execute_model(scheduler_output, non_block=True)
-            grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-            with self.log_error_detail(scheduler_output):
-                model_output = future.result()
-                if model_output is None:
-                    model_output = self.model_executor.sample_tokens(grammar_output)
-
-        with record_function_or_nullcontext("core step: update_from_output"):
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output
-            )
+        scheduler_output = self.scheduler.schedule()
+        future = self.model_executor.execute_model(scheduler_output, non_block=True)
+        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+        with self.log_error_detail(scheduler_output):
+            model_output = future.result()
+            if model_output is None:
+                model_output = self.model_executor.sample_tokens(grammar_output)
+
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output
+        )
 
         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
 
@@ -378,52 +385,34 @@ def step_with_batch_queue(
         model_executed = False
         deferred_scheduler_output = None
         if self.scheduler.has_requests():
-            with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
-                scheduler_output = self.scheduler.schedule()
-            with record_function_or_nullcontext(
-                "core step_with_batch_queue: execute_model"
-            ):
-                exec_future = self.model_executor.execute_model(
-                    scheduler_output, non_block=True
-                )
-            model_executed = scheduler_output.total_num_scheduled_tokens > 0
+            scheduler_output = self.scheduler.schedule()
+            exec_future = self.model_executor.execute_model(
+                scheduler_output, non_block=True
+            )
+            if not self.ec_producer:
+                model_executed = scheduler_output.total_num_scheduled_tokens > 0
 
-            if scheduler_output.pending_structured_output_tokens:
-                with record_function_or_nullcontext(
-                    "core step_with_batch_queue: pending_structured_output_tokens"
-                ):
-                    # We need to defer sampling until we have processed the model output
-                    # from the prior step.
-                    deferred_scheduler_output = scheduler_output
-                    # Block-wait for execute to return
-                    # (continues running async on the GPU).
-                    with self.log_error_detail(scheduler_output):
-                        exec_result = exec_future.result()
-                        assert exec_result is None
+            if not model_executed:
+                # No sampling required (no requests scheduled).
+                future = cast(Future[ModelRunnerOutput], exec_future)
             else:
-                with record_function_or_nullcontext(
-                    "core step_with_batch_queue: get_grammar_bitmask"
-                ):
-                    # We aren't waiting for any tokens, get any grammar
-                    # output immediately.
+                exec_future.add_done_callback(self._log_err_callback(scheduler_output))
+
+                if not scheduler_output.pending_structured_output_tokens:
+                    # We aren't waiting for any tokens, get any grammar output
+                    # and sample immediately.
                     grammar_output = self.scheduler.get_grammar_bitmask(
                         scheduler_output
                     )
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-
-                if exec_result is None:
-                    with record_function_or_nullcontext(
-                        "core step_with_batch_queue: sample_tokens"
-                    ):
-                        # Call sample tokens.
-                        future = self.model_executor.sample_tokens(
-                            grammar_output, non_block=True
-                        )
+                    future = self.model_executor.sample_tokens(
+                        grammar_output, non_block=True
+                    )
                 else:
-                    # No sampling required (e.g. all requests finished).
-                    future = cast(Future[ModelRunnerOutput], exec_future)
+                    # We need to defer sampling until we have processed the model output
+                    # from the prior step.
+                    deferred_scheduler_output = scheduler_output
+
+            if not deferred_scheduler_output:
                 # Add this step's future to the queue.
                 batch_queue.appendleft((future, scheduler_output))
                 if (
@@ -440,34 +429,27 @@
             # only be called when the scheduler contains requests or the queue
             # is non-empty.
             return None, False
-        with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
-            # Block until the next result is available.
-            future, scheduler_output = batch_queue.pop()
-            with self.log_error_detail(scheduler_output):
-                model_output = future.result()
-        with record_function_or_nullcontext(
-            "core step_with_batch_queue: update_from_output"
-        ):
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output
-            )
+
+        # Block until the next result is available.
+        future, scheduler_output = batch_queue.pop()
+        with self.log_error_detail(scheduler_output):
+            model_output = future.result()
+
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output
+        )
 
         # NOTE(nick): We can either handle the deferred tasks here or save
         # in a field and do it immediately once step_with_batch_queue is
         # re-called. The latter slightly favors TTFT over TPOT/throughput.
         if deferred_scheduler_output:
-            with record_function_or_nullcontext(
-                "core step_with_batch_queue: deferred_scheduler_output"
-            ):
-                # We now have the tokens needed to compute the bitmask for the
-                # deferred request. Get the bitmask and call sample tokens.
-                grammar_output = self.scheduler.get_grammar_bitmask(
-                    deferred_scheduler_output
-                )
-                future = self.model_executor.sample_tokens(
-                    grammar_output, non_block=True
-                )
-                batch_queue.appendleft((future, deferred_scheduler_output))
+            # We now have the tokens needed to compute the bitmask for the
+            # deferred request. Get the bitmask and call sample tokens.
+            grammar_output = self.scheduler.get_grammar_bitmask(
+                deferred_scheduler_output
+            )
+            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
+            batch_queue.appendleft((future, deferred_scheduler_output))
 
         return engine_core_outputs, model_executed
 
```
vllm/v1/executor/ray_executor.py

Lines changed: 20 additions & 1 deletion
```diff
@@ -99,6 +99,11 @@ def _init_executor(self) -> None:
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
 
+        self.ec_producer = (
+            self.vllm_config.ec_transfer_config is not None
+            and self.vllm_config.ec_transfer_config.is_ec_producer
+        )
+
         self.scheduler_output: SchedulerOutput | None = None
 
     @property
@@ -395,6 +400,12 @@ def execute_model(  # type: ignore[override]
             "State error: sample_tokens() must be called "
             "after execute_model() returns None."
         )
+
+        if self.ec_producer or not scheduler_output.total_num_scheduled_tokens:
+            # Model will not execute, call model runner immediately.
+            return self._execute_dag(scheduler_output, None, non_block)
+
+        # Model will execute, defer to sample_tokens() call.
         self.scheduler_output = scheduler_output
         return COMPLETED_NONE_FUTURE if non_block else None
 
@@ -417,10 +428,18 @@ def sample_tokens(  # type: ignore[override]
         """
         scheduler_output = self.scheduler_output
         if scheduler_output is None:
-            return None  # noqa
+            return COMPLETED_NONE_FUTURE if non_block else None  # noqa
 
         self.scheduler_output = None
 
+        return self._execute_dag(scheduler_output, grammar_output, non_block)
+
+    def _execute_dag(
+        self,
+        scheduler_output: SchedulerOutput,
+        grammar_output: "GrammarOutput | None",
+        non_block: bool = False,
+    ) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
         # Build the compiled DAG for the first time.
         if self.forward_dag is None:  # type: ignore
             self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
```
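Both entry points now funnel into the new `_execute_dag()` helper, with EC-producer and zero-token steps short-circuiting in `execute_model()`. A rough sketch of the resulting control flow; `ToySchedulerOutput` and `ToyRayExecutor` are simplified stand-ins, not the real vLLM types, and only the branching mirrors the diff above:

```python
# Rough sketch of the execute_model()/sample_tokens() split after this change.
from concurrent.futures import Future
from dataclasses import dataclass


def completed_none_future() -> Future:
    f: Future = Future()
    f.set_result(None)
    return f


@dataclass
class ToySchedulerOutput:
    total_num_scheduled_tokens: int


class ToyRayExecutor:
    def __init__(self, ec_producer: bool):
        self.ec_producer = ec_producer
        self.scheduler_output: ToySchedulerOutput | None = None

    def execute_model(self, scheduler_output, non_block=False):
        if self.ec_producer or not scheduler_output.total_num_scheduled_tokens:
            # Model will not execute: run the DAG (model runner) right away.
            return self._execute_dag(scheduler_output, None, non_block)
        # Model will execute: defer the DAG launch to sample_tokens().
        self.scheduler_output = scheduler_output
        return completed_none_future() if non_block else None

    def sample_tokens(self, grammar_output, non_block=False):
        scheduler_output = self.scheduler_output
        if scheduler_output is None:
            # execute_model() already launched (or skipped) this step.
            return completed_none_future() if non_block else None
        self.scheduler_output = None
        return self._execute_dag(scheduler_output, grammar_output, non_block)

    def _execute_dag(self, scheduler_output, grammar_output, non_block=False):
        # Placeholder: the real executor builds and runs a compiled Ray DAG here.
        return completed_none_future() if non_block else None
```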

vllm/v1/worker/gpu_model_runner.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -7,7 +7,7 @@
 from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
-from copy import deepcopy
+from copy import copy, deepcopy
 from functools import reduce
 from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -250,7 +250,6 @@ class ExecuteModelState(NamedTuple):
     hidden_states: torch.Tensor
     sample_hidden_states: torch.Tensor
     aux_hidden_states: list[torch.Tensor] | None
-    kv_connector_output: KVConnectorOutput | None
     ec_connector_output: ECConnectorOutput | None
 
 
@@ -573,6 +572,7 @@ def __init__(
 
         # Ephemeral state transferred between execute_model() and sample_tokens().
        self.execute_model_state: ExecuteModelState | None = None
+        self.kv_connector_output: KVConnectorOutput | None = None
 
     def reset_mm_cache(self) -> None:
         if self.mm_budget:
@@ -2803,6 +2803,7 @@ def execute_model(
             # Return the intermediate tensors.
             assert isinstance(hidden_states, IntermediateTensors)
             hidden_states.kv_connector_output = kv_connector_output
+            self.kv_connector_output = kv_connector_output
             return hidden_states
 
         if self.is_pooling_model:
@@ -2853,18 +2854,31 @@ def execute_model(
             hidden_states,
             sample_hidden_states,
             aux_hidden_states,
-            kv_connector_output,
             ec_connector_output,
         )
+        self.kv_connector_output = kv_connector_output
         return None
 
     @torch.inference_mode
     def sample_tokens(
         self, grammar_output: "GrammarOutput | None"
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
+        kv_connector_output = self.kv_connector_output
+        self.kv_connector_output = None
+
         if self.execute_model_state is None:
             # Nothing to do (PP non-final rank case), output isn't used.
-            return None  # noqa
+            if not kv_connector_output:
+                return None  # noqa
+
+            # In case of PP with kv transfer, we need to pass through the
+            # kv_connector_output
+            if kv_connector_output.is_empty():
+                return EMPTY_MODEL_RUNNER_OUTPUT
+
+            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
+            output.kv_connector_output = kv_connector_output
+            return output
 
         # Unpack ephemeral state.
         (
@@ -2875,7 +2889,6 @@ def sample_tokens(
             hidden_states,
             sample_hidden_states,
             aux_hidden_states,
-            kv_connector_output,
             ec_connector_output,
         ) = self.execute_model_state
         # Clear ephemeral state.
```
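The KV-connector pass-through that previously lived in gpu_worker.py now happens inside `sample_tokens()` on PP non-final ranks. A condensed sketch of that branch, using simplified stand-in types rather than the real vLLM output classes:

```python
# Condensed sketch of the new pass-through for PP non-final ranks.
# KVConnectorOutput / ModelRunnerOutput below are simplified stand-ins.
from __future__ import annotations

from copy import copy
from dataclasses import dataclass, field


@dataclass
class KVConnectorOutput:
    finished_sending: set = field(default_factory=set)
    finished_recving: set = field(default_factory=set)

    def is_empty(self) -> bool:
        return not (self.finished_sending or self.finished_recving)


@dataclass
class ModelRunnerOutput:
    kv_connector_output: KVConnectorOutput | None = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput()


def sample_tokens_passthrough(kv_connector_output: KVConnectorOutput | None):
    """What a PP non-final rank returns from sample_tokens() after this change."""
    if not kv_connector_output:
        return None  # no KV connector activity to report
    if kv_connector_output.is_empty():
        return EMPTY_MODEL_RUNNER_OUTPUT  # safe to reuse the shared singleton
    # Shallow-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is never mutated.
    output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = kv_connector_output
    return output
```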

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 14 deletions
```diff
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A GPU worker class."""
 
-import copy
 import gc
 import os
 from contextlib import AbstractContextManager, nullcontext
@@ -45,7 +44,6 @@
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import (
-    EMPTY_MODEL_RUNNER_OUTPUT,
     AsyncModelRunnerOutput,
     DraftTokenIds,
     ModelRunnerOutput,
@@ -581,18 +579,7 @@ def execute_model(
             all_gather_tensors=all_gather_tensors,
         )
 
-        kv_connector_output = output.kv_connector_output
-        if not kv_connector_output:
-            return None
-
-        # In case of PP with kv transfer, we need to pass through the
-        # kv_connector_output
-        if kv_connector_output.is_empty():
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.kv_connector_output = kv_connector_output
-        return output
+        return None
 
     def take_draft_token_ids(self) -> DraftTokenIds | None:
         return self.model_runner.take_draft_token_ids()
```
