diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a6965182fc2c..9a6256623cdb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -63,7 +63,6 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder from vllm.v1.structured_output import StructuredOutputManager -from vllm.v1.utils import record_function_or_nullcontext from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -181,11 +180,13 @@ def __init__( logger.info("Batch queue is enabled with size %d", self.batch_queue_size) self.batch_queue = deque(maxlen=self.batch_queue_size) + self.ec_producer = ( + vllm_config.ec_transfer_config is not None + and vllm_config.ec_transfer_config.is_ec_producer + ) + self.request_block_hasher: Callable[[Request], list[BlockHash]] | None = None - if ( - self.vllm_config.cache_config.enable_prefix_caching - or kv_connector is not None - ): + if vllm_config.cache_config.enable_prefix_caching or kv_connector is not None: caching_hash_fn = get_hash_fn_by_name( vllm_config.cache_config.prefix_caching_hash_algo ) @@ -245,7 +246,7 @@ def _initialize_kv_caches( elapsed = time.time() - start logger.info_once( - ("init engine (profile, create kv cache, warmup model) took %.2f seconds"), + "init engine (profile, create kv cache, warmup model) took %.2f seconds", elapsed, scope="local", ) @@ -311,6 +312,16 @@ def log_error_detail(self, scheduler_output: SchedulerOutput): ) raise err + def _log_err_callback(self, scheduler_output: SchedulerOutput): + """Log error details of a future that's not expected to return a result.""" + + def callback(f, sched_output=scheduler_output): + with self.log_error_detail(sched_output): + result = f.result() + assert result is None + + return callback + def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: """Schedule, execute, and make output. @@ -322,21 +333,17 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: # or finished and not yet removed from the batch. 
if not self.scheduler.has_requests(): return {}, False - with record_function_or_nullcontext("core step: schedule"): - scheduler_output = self.scheduler.schedule() - - with record_function_or_nullcontext("core step: execute_model"): - future = self.model_executor.execute_model(scheduler_output, non_block=True) - grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) - with self.log_error_detail(scheduler_output): - model_output = future.result() - if model_output is None: - model_output = self.model_executor.sample_tokens(grammar_output) - - with record_function_or_nullcontext("core step: update_from_output"): - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + scheduler_output = self.scheduler.schedule() + future = self.model_executor.execute_model(scheduler_output, non_block=True) + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + with self.log_error_detail(scheduler_output): + model_output = future.result() + if model_output is None: + model_output = self.model_executor.sample_tokens(grammar_output) + + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0 @@ -374,52 +381,34 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None if self.scheduler.has_requests(): - with record_function_or_nullcontext("core step_with_batch_queue: schedule"): - scheduler_output = self.scheduler.schedule() - with record_function_or_nullcontext( - "core step_with_batch_queue: execute_model" - ): - exec_future = self.model_executor.execute_model( - scheduler_output, non_block=True - ) - model_executed = scheduler_output.total_num_scheduled_tokens > 0 + scheduler_output = self.scheduler.schedule() + exec_future = self.model_executor.execute_model( + scheduler_output, non_block=True + ) + if not self.ec_producer: + model_executed = scheduler_output.total_num_scheduled_tokens > 0 - if scheduler_output.pending_structured_output_tokens: - with record_function_or_nullcontext( - "core step_with_batch_queue: pending_structured_output_tokens" - ): - # We need to defer sampling until we have processed the model output - # from the prior step. - deferred_scheduler_output = scheduler_output - # Block-wait for execute to return - # (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - assert exec_result is None + if not model_executed: + # No sampling required (no requests scheduled). + future = cast(Future[ModelRunnerOutput], exec_future) else: - with record_function_or_nullcontext( - "core step_with_batch_queue: get_grammar_bitmask" - ): - # We aren't waiting for any tokens, get any grammar - # output immediately. + exec_future.add_done_callback(self._log_err_callback(scheduler_output)) + + if not scheduler_output.pending_structured_output_tokens: + # We aren't waiting for any tokens, get any grammar output + # and sample immediately. grammar_output = self.scheduler.get_grammar_bitmask( scheduler_output ) - # Block-wait for execute to return (continues running async on the GPU). - with self.log_error_detail(scheduler_output): - exec_result = exec_future.result() - - if exec_result is None: - with record_function_or_nullcontext( - "core step_with_batch_queue: sample_tokens" - ): - # Call sample tokens. 
- future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) + future = self.model_executor.sample_tokens( + grammar_output, non_block=True + ) else: - # No sampling required (e.g. all requests finished). - future = cast(Future[ModelRunnerOutput], exec_future) + # We need to defer sampling until we have processed the model output + # from the prior step. + deferred_scheduler_output = scheduler_output + + if not deferred_scheduler_output: # Add this step's future to the queue. batch_queue.appendleft((future, scheduler_output)) if ( @@ -436,34 +425,27 @@ def step_with_batch_queue( # only be called when the scheduler contains requests or the queue # is non-empty. return None, False - with record_function_or_nullcontext("core step_with_batch_queue: model_output"): - # Block until the next result is available. - future, scheduler_output = batch_queue.pop() - with self.log_error_detail(scheduler_output): - model_output = future.result() - with record_function_or_nullcontext( - "core step_with_batch_queue: update_from_output" - ): - engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, model_output - ) + + # Block until the next result is available. + future, scheduler_output = batch_queue.pop() + with self.log_error_detail(scheduler_output): + model_output = future.result() + + engine_core_outputs = self.scheduler.update_from_output( + scheduler_output, model_output + ) # NOTE(nick): We can either handle the deferred tasks here or save # in a field and do it immediately once step_with_batch_queue is # re-called. The latter slightly favors TTFT over TPOT/throughput. if deferred_scheduler_output: - with record_function_or_nullcontext( - "core step_with_batch_queue: deferred_scheduler_output" - ): - # We now have the tokens needed to compute the bitmask for the - # deferred request. Get the bitmask and call sample tokens. - grammar_output = self.scheduler.get_grammar_bitmask( - deferred_scheduler_output - ) - future = self.model_executor.sample_tokens( - grammar_output, non_block=True - ) - batch_queue.appendleft((future, deferred_scheduler_output)) + # We now have the tokens needed to compute the bitmask for the + # deferred request. Get the bitmask and call sample tokens. + grammar_output = self.scheduler.get_grammar_bitmask( + deferred_scheduler_output + ) + future = self.model_executor.sample_tokens(grammar_output, non_block=True) + batch_queue.appendleft((future, deferred_scheduler_output)) return engine_core_outputs, model_executed diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 119e4c081831..55db7445c9c7 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -99,6 +99,11 @@ def _init_executor(self) -> None: # KV connector setup self.has_connector = self.vllm_config.kv_transfer_config is not None + self.ec_producer = ( + self.vllm_config.ec_transfer_config is not None + and self.vllm_config.ec_transfer_config.is_ec_producer + ) + self.scheduler_output: SchedulerOutput | None = None @property @@ -395,6 +400,12 @@ def execute_model( # type: ignore[override] "State error: sample_tokens() must be called " "after execute_model() returns None." ) + + if self.ec_producer or not scheduler_output.total_num_scheduled_tokens: + # Model will not execute, call model runner immediately. + return self._execute_dag(scheduler_output, None, non_block) + + # Model will execute, defer to sample_tokens() call. 
self.scheduler_output = scheduler_output return COMPLETED_NONE_FUTURE if non_block else None @@ -417,10 +428,18 @@ def sample_tokens( # type: ignore[override] """ scheduler_output = self.scheduler_output if scheduler_output is None: - return None # noqa + return COMPLETED_NONE_FUTURE if non_block else None # noqa self.scheduler_output = None + return self._execute_dag(scheduler_output, grammar_output, non_block) + + def _execute_dag( + self, + scheduler_output: SchedulerOutput, + grammar_output: "GrammarOutput | None", + non_block: bool = False, + ) -> ModelRunnerOutput | Future[ModelRunnerOutput]: # Build the compiled DAG for the first time. if self.forward_dag is None: # type: ignore self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6590ca54af68..a32f63ae4dd8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,7 +7,7 @@ from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager -from copy import deepcopy +from copy import copy, deepcopy from functools import reduce from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast @@ -242,7 +242,6 @@ class ExecuteModelState(NamedTuple): hidden_states: torch.Tensor sample_hidden_states: torch.Tensor aux_hidden_states: list[torch.Tensor] | None - kv_connector_output: KVConnectorOutput | None ec_connector_output: ECConnectorOutput | None @@ -551,6 +550,7 @@ def __init__( # Ephemeral state transferred between execute_model() and sample_tokens(). self.execute_model_state: ExecuteModelState | None = None + self.kv_connector_output: KVConnectorOutput | None = None def reset_mm_cache(self) -> None: if self.mm_budget: @@ -2682,6 +2682,7 @@ def execute_model( # Return the intermediate tensors. assert isinstance(hidden_states, IntermediateTensors) hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output return hidden_states if self.is_pooling_model: @@ -2732,18 +2733,31 @@ def execute_model( hidden_states, sample_hidden_states, aux_hidden_states, - kv_connector_output, ec_connector_output, ) + self.kv_connector_output = kv_connector_output return None @torch.inference_mode def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: + kv_connector_output = self.kv_connector_output + self.kv_connector_output = None + if self.execute_model_state is None: # Nothing to do (PP non-final rank case), output isn't used. - return None # noqa + if not kv_connector_output: + return None # noqa + + # In case of PP with kv transfer, we need to pass through the + # kv_connector_output + if kv_connector_output.is_empty(): + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.kv_connector_output = kv_connector_output + return output # Unpack ephemeral state. ( @@ -2754,7 +2768,6 @@ def sample_tokens( hidden_states, sample_hidden_states, aux_hidden_states, - kv_connector_output, ec_connector_output, ) = self.execute_model_state # Clear ephemeral state. 
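
For context, here is a minimal standalone sketch (not part of the patch) of the pass-through pattern that gpu_model_runner.sample_tokens() now implements; the PP-with-KV-transfer handling previously lived in gpu_worker.execute_model() (removed in the next file), with kv_connector_output now stashed on the runner between the two calls. The classes below are simplified stand-ins, not vLLM's real KVConnectorOutput / ModelRunnerOutput definitions; assumed Python 3.10+ for the `X | None` annotations.

# Sketch only: a shared "empty output" sentinel is never mutated; a shallow copy
# carries the per-step connector output back to the scheduler.
from copy import copy
from dataclasses import dataclass, field


@dataclass
class StubKVConnectorOutput:
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)

    def is_empty(self) -> bool:
        return not (self.finished_sending or self.finished_recving)


@dataclass
class StubModelRunnerOutput:
    req_ids: list[str] = field(default_factory=list)
    kv_connector_output: StubKVConnectorOutput | None = None


EMPTY_OUTPUT = StubModelRunnerOutput()


def kv_passthrough(kv_connector_output: StubKVConnectorOutput | None):
    # Mirrors the "nothing to sample" (non-final PP rank) branch above.
    if not kv_connector_output:
        # No transfer status to report.
        return None
    if kv_connector_output.is_empty():
        # Status object exists but carries nothing; the shared sentinel suffices.
        return EMPTY_OUTPUT
    out = copy(EMPTY_OUTPUT)  # shallow copy keeps the shared sentinel pristine
    out.kv_connector_output = kv_connector_output
    return out

The shallow copy is the important detail: EMPTY_MODEL_RUNNER_OUTPUT is a shared constant, so attaching the connector output to a copy rather than to the constant itself avoids leaking per-step state into later steps.
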
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 283e3744bcf6..e24f44b3d1e8 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" -import copy import gc import os from contextlib import AbstractContextManager, nullcontext @@ -45,7 +44,6 @@ from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ( - EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput, @@ -573,18 +571,7 @@ def execute_model( all_gather_tensors=all_gather_tensors, ) - kv_connector_output = output.kv_connector_output - if not kv_connector_output: - return None - - # In case of PP with kv transfer, we need to pass through the - # kv_connector_output - if kv_connector_output.is_empty(): - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.kv_connector_output = kv_connector_output - return output + return None def take_draft_token_ids(self) -> DraftTokenIds | None: return self.model_runner.take_draft_token_ids()
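
For context, here is a minimal standalone sketch (not part of the patch) of the two-phase handshake that the executor-side changes implement: execute_model() either runs the work immediately (EC producer, or nothing scheduled) or stashes the batch and returns an already-completed None future, and sample_tokens() later performs the deferred run once the grammar bitmask is available. The names below are simplified stand-ins for the vLLM types, using plain concurrent.futures rather than the compiled Ray DAG; assumed Python 3.10+ for the union annotations.

# Sketch only: simplified stand-ins, not the real vLLM executor API.
from concurrent.futures import Future

COMPLETED_NONE_FUTURE: Future = Future()
COMPLETED_NONE_FUTURE.set_result(None)


class MiniExecutor:
    def __init__(self, is_ec_producer: bool):
        self.is_ec_producer = is_ec_producer
        self._pending_batch: dict | None = None

    def execute_model(self, batch: dict) -> Future:
        # Any earlier batch must already have been drained by sample_tokens().
        assert self._pending_batch is None, "sample_tokens() not called for prior batch"
        if self.is_ec_producer or not batch["num_scheduled_tokens"]:
            # Nothing will be sampled for this batch; run it right away.
            return self._run(batch, grammar=None)
        # Defer the real work until sample_tokens() supplies the grammar bitmask.
        self._pending_batch = batch
        return COMPLETED_NONE_FUTURE

    def sample_tokens(self, grammar) -> Future:
        batch, self._pending_batch = self._pending_batch, None
        if batch is None:
            # execute_model() already ran everything; nothing left to do.
            return COMPLETED_NONE_FUTURE
        return self._run(batch, grammar)

    def _run(self, batch: dict, grammar) -> Future:
        # Stand-in for launching the compiled Ray DAG / GPU work.
        f: Future = Future()
        f.set_result({"batch": batch, "grammar": grammar})
        return f


# Usage under the same assumptions:
ex = MiniExecutor(is_ec_producer=False)
fut = ex.execute_model({"num_scheduled_tokens": 8})
assert fut.result() is None                         # completed-None placeholder
out = ex.sample_tokens(grammar="bitmask").result()  # deferred run happens here

This mirrors the state check already present in the Ray executor ("sample_tokens() must be called after execute_model() returns None") together with the new early-execute path for EC producers and empty batches.
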