vllm/v1/engine/core.py (150 changes: 66 additions & 84 deletions)
@@ -63,7 +63,6 @@
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -181,11 +180,13 @@ def __init__(
logger.info("Batch queue is enabled with size %d", self.batch_queue_size)
self.batch_queue = deque(maxlen=self.batch_queue_size)

self.ec_producer = (
    vllm_config.ec_transfer_config is not None
    and vllm_config.ec_transfer_config.is_ec_producer
)

Review comment (Collaborator):
nit:

Suggested change
- self.ec_producer = (
+ self.is_ec_producer = (

Reply (Member, Author):
I'll open a follow-on PR so that this doesn't hold up the release.

Reply (Member, Author):
@WoosukKwon opened follow-on #28884 for this.

self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None
if (
self.vllm_config.cache_config.enable_prefix_caching
or kv_connector is not None
):
if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None:
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo
)
@@ -245,7 +246,7 @@ def _initialize_kv_caches(

elapsed = time.time() - start
logger.info_once(
("init engine (profile, create kv cache, warmup model) took %.2f seconds"),
"init engine (profile, create kv cache, warmup model) took %.2f seconds",
elapsed,
scope="local",
)
@@ -311,6 +312,16 @@ def log_error_detail(self, scheduler_output: SchedulerOutput):
)
raise err

def _log_err_callback(self, scheduler_output: SchedulerOutput):
"""Log error details of a future that's not expected to return a result."""

def callback(f, sched_output=scheduler_output):
with self.log_error_detail(sched_output):
result = f.result()
assert result is None

return callback

def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
"""Schedule, execute, and make output.

@@ -322,21 +333,17 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
with record_function_or_nullcontext("core step: schedule"):
scheduler_output = self.scheduler.schedule()

with record_function_or_nullcontext("core step: execute_model"):
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)

with record_function_or_nullcontext("core step: update_from_output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0

@@ -374,52 +381,34 @@ def step_with_batch_queue(
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext(
"core step_with_batch_queue: execute_model"
):
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
scheduler_output = self.scheduler.schedule()
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
if not self.ec_producer:
model_executed = scheduler_output.total_num_scheduled_tokens > 0

if scheduler_output.pending_structured_output_tokens:
with record_function_or_nullcontext(
"core step_with_batch_queue: pending_structured_output_tokens"
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return
# (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
if not model_executed:
# No sampling required (no requests scheduled).
future = cast(Future[ModelRunnerOutput], exec_future)
else:
with record_function_or_nullcontext(
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, get any grammar
# output immediately.
exec_future.add_done_callback(self._log_err_callback(scheduler_output))

if not scheduler_output.pending_structured_output_tokens:
# We aren't waiting for any tokens, get any grammar output
# and sample immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()

if exec_result is None:
with record_function_or_nullcontext(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output

if not deferred_scheduler_output:
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
@@ -436,34 +425,27 @@
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
with record_function_or_nullcontext(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()

engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)

# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
with record_function_or_nullcontext(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
batch_queue.appendleft((future, deferred_scheduler_output))
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))

return engine_core_outputs, model_executed

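For context, step_with_batch_queue() above pipelines steps through a small deque of futures and defers sampling whenever a step still depends on structured-output tokens from the previous step. The following is a minimal, self-contained sketch of that pattern; SchedStep and FakeExecutor are hypothetical stand-ins, not the actual vLLM scheduler or executor APIs, and only the control flow mirrors the diff.

```python
# Minimal sketch of the deferred-sampling batch-queue pattern used by
# step_with_batch_queue() above. SchedStep / FakeExecutor are hypothetical
# stand-ins: execute the model asynchronously, sample immediately when no
# structured-output tokens are pending, otherwise defer sampling until the
# previous step's output has been processed.
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass


@dataclass
class SchedStep:
    step_id: int
    pending_structured_output_tokens: bool


class FakeExecutor:
    """Runs 'model' work on a background thread to imitate non_block=True."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)

    def execute_model(self, step: SchedStep) -> Future:
        # Forward pass only; returns None, sampling happens separately.
        return self._pool.submit(lambda: None)

    def sample_tokens(self, step: SchedStep) -> Future:
        return self._pool.submit(lambda: f"tokens-for-step-{step.step_id}")


def step_with_batch_queue(executor: FakeExecutor,
                          batch_queue: deque,
                          step: SchedStep):
    """One simplified iteration: schedule/execute, then drain one result."""
    exec_future = executor.execute_model(step)

    deferred = None
    if step.pending_structured_output_tokens:
        # The grammar bitmask needs tokens from the previous step: block-wait
        # for the forward pass, but defer sampling.
        assert exec_future.result() is None
        deferred = step
    else:
        # Nothing pending: sample right away and queue the future.
        batch_queue.appendleft((executor.sample_tokens(step), step))

    result = None
    if batch_queue:
        # Block until the oldest in-flight step finishes.
        future, done_step = batch_queue.pop()
        result = (done_step.step_id, future.result())

    if deferred is not None:
        # The previous output has now been processed, so the deferred step
        # can finally be sampled.
        batch_queue.appendleft((executor.sample_tokens(deferred), deferred))
    return result


if __name__ == "__main__":
    ex, q = FakeExecutor(), deque()
    print(step_with_batch_queue(ex, q, SchedStep(0, False)))  # (0, 'tokens-for-step-0')
    print(step_with_batch_queue(ex, q, SchedStep(1, True)))   # None; step 1 deferred
    print(step_with_batch_queue(ex, q, SchedStep(2, False)))  # (1, 'tokens-for-step-1')
```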
vllm/v1/executor/ray_executor.py (21 changes: 20 additions & 1 deletion)
@@ -99,6 +99,11 @@ def _init_executor(self) -> None:
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None

self.ec_producer = (
self.vllm_config.ec_transfer_config is not None
and self.vllm_config.ec_transfer_config.is_ec_producer
)

self.scheduler_output: SchedulerOutput | None = None

@property
@@ -395,6 +400,12 @@ def execute_model( # type: ignore[override]
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)

if self.ec_producer or not scheduler_output.total_num_scheduled_tokens:
# Model will not execute, call model runner immediately.
return self._execute_dag(scheduler_output, None, non_block)

# Model will execute, defer to sample_tokens() call.
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None

@@ -417,10 +428,18 @@ def sample_tokens( # type: ignore[override]
"""
scheduler_output = self.scheduler_output
if scheduler_output is None:
return None # noqa
return COMPLETED_NONE_FUTURE if non_block else None # noqa

self.scheduler_output = None

return self._execute_dag(scheduler_output, grammar_output, non_block)

def _execute_dag(
self,
scheduler_output: SchedulerOutput,
grammar_output: "GrammarOutput | None",
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
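A rough illustration of the execute_model()/sample_tokens() split in the ray_executor.py diff above: the executor either runs the compiled DAG immediately (EC producer, or nothing scheduled) or stashes the SchedulerOutput and defers the real submission until sample_tokens() supplies the grammar output. RayLikeExecutor and _run_dag() below are made-up names for a simplified sketch, not the Ray compiled-DAG API.

```python
# Simplified sketch of the stash-and-defer handoff between execute_model()
# and sample_tokens(). RayLikeExecutor and _run_dag() are hypothetical
# stand-ins for the compiled-DAG machinery in the diff above.
from concurrent.futures import Future
from typing import Any


def _completed_none_future() -> Future:
    f: Future = Future()
    f.set_result(None)
    return f


class RayLikeExecutor:
    def __init__(self, is_ec_producer: bool) -> None:
        self.is_ec_producer = is_ec_producer
        self._pending_scheduler_output: Any = None

    def _run_dag(self, scheduler_output: Any, grammar_output: Any) -> Future:
        # Stand-in for executing the compiled DAG; here it completes at once.
        f: Future = Future()
        f.set_result({"sched": scheduler_output, "grammar": grammar_output})
        return f

    def execute_model(self, scheduler_output: Any,
                      num_scheduled_tokens: int) -> Future:
        if self.is_ec_producer or num_scheduled_tokens == 0:
            # The model will not sample anything: run the DAG immediately.
            return self._run_dag(scheduler_output, None)
        # Otherwise remember the scheduler output and hand back a placeholder;
        # the real submission happens in sample_tokens().
        self._pending_scheduler_output = scheduler_output
        return _completed_none_future()

    def sample_tokens(self, grammar_output: Any) -> Future:
        scheduler_output = self._pending_scheduler_output
        if scheduler_output is None:
            return _completed_none_future()
        self._pending_scheduler_output = None
        return self._run_dag(scheduler_output, grammar_output)
```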
vllm/v1/worker/gpu_model_runner.py (23 changes: 18 additions & 5 deletions)
@@ -7,7 +7,7 @@
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from copy import copy, deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -242,7 +242,6 @@ class ExecuteModelState(NamedTuple):
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
ec_connector_output: ECConnectorOutput | None


@@ -551,6 +550,7 @@ def __init__(

# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None

def reset_mm_cache(self) -> None:
if self.mm_budget:
@@ -2682,6 +2682,7 @@ def execute_model(
# Return the intermediate tensors.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.kv_connector_output = kv_connector_output
return hidden_states

if self.is_pooling_model:
@@ -2732,18 +2733,31 @@ def execute_model(
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
ec_connector_output,
)
self.kv_connector_output = kv_connector_output
return None

@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None

if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
if not kv_connector_output:
return None # noqa

# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT

output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output

# Unpack ephemeral state.
(
@@ -2754,7 +2768,6 @@ def sample_tokens(
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
ec_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
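The gpu_model_runner.py change keeps kv_connector_output as ephemeral state on the runner rather than inside ExecuteModelState, and moves the PP pass-through of an otherwise-empty output into sample_tokens(). A hedged sketch of just that pass-through follows; KVConnectorOutput and ModelRunnerOutput here are simplified placeholders, the real classes carry more fields.

```python
# Sketch of the KV-connector pass-through that sample_tokens() now handles
# for PP non-final ranks. KVConnectorOutput / ModelRunnerOutput here are
# simplified placeholders, not the real vLLM classes.
from copy import copy
from dataclasses import dataclass, field


@dataclass
class KVConnectorOutput:
    finished_sending: list[str] = field(default_factory=list)

    def is_empty(self) -> bool:
        return not self.finished_sending


@dataclass
class ModelRunnerOutput:
    sampled_token_ids: list[list[int]] = field(default_factory=list)
    kv_connector_output: KVConnectorOutput | None = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput()


class RunnerSketch:
    def __init__(self) -> None:
        # Ephemeral state handed from execute_model() to sample_tokens().
        self.kv_connector_output: KVConnectorOutput | None = None
        self.execute_model_state = None  # set by execute_model() on final ranks

    def sample_tokens(self, grammar_output=None):
        kv_connector_output = self.kv_connector_output
        self.kv_connector_output = None

        if self.execute_model_state is None:
            # PP non-final rank: nothing was sampled here.
            if not kv_connector_output:
                return None
            if kv_connector_output.is_empty():
                return EMPTY_MODEL_RUNNER_OUTPUT
            # Pass the connector output through on a shallow copy so the
            # shared empty sentinel is never mutated.
            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
            output.kv_connector_output = kv_connector_output
            return output

        ...  # normal sampling path omitted in this sketch
```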
vllm/v1/worker/gpu_worker.py (15 changes: 1 addition & 14 deletions)
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""

import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
@@ -45,7 +44,6 @@
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
ModelRunnerOutput,
@@ -573,18 +571,7 @@ def execute_model(
all_gather_tensors=all_gather_tensors,
)

kv_connector_output = output.kv_connector_output
if not kv_connector_output:
return None

# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT

output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
return None

def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.model_runner.take_draft_token_ids()