Commit d78fda7

[Renderer] Move Processor out of LLMEngine (vllm-project#26165)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 73a99cc commit d78fda7

4 files changed: +107 −52 lines changed

vllm/entrypoints/llm.py

Lines changed: 57 additions & 13 deletions
@@ -37,6 +37,7 @@
                                     log_non_default_args)
 from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                          TokensPrompt)
+from vllm.inputs.parse import get_prompt_components
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -49,10 +50,13 @@
                                  SamplingParams)
 from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
-                                               get_cached_tokenizer)
+                                               get_cached_tokenizer,
+                                               init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, Device, as_iter, is_list_of
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.llm_engine import LLMEngine
+from vllm.v1.engine.processor import Processor
 from vllm.v1.sample.logits_processor import LogitsProcessor
 
 if TYPE_CHECKING:
@@ -312,6 +316,10 @@ def __init__(
         self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                              io_processor_plugin)
 
+    @property
+    def model_config(self):
+        return self.llm_engine.model_config
+
     def get_tokenizer(self) -> AnyTokenizer:
         return self.llm_engine.get_tokenizer()
 
@@ -324,6 +332,16 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
         else:
             self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
 
+    def _get_processor(self) -> Processor:
+        if not hasattr(self, "_processor"):
+            vllm_config = self.llm_engine.vllm_config
+            if self.model_config.skip_tokenizer_init:
+                tokenizer = None
+            else:
+                tokenizer = init_tokenizer_from_configs(self.model_config)
+            self._processor = Processor(vllm_config, tokenizer)
+        return self._processor
+
     def get_default_sampling_params(self) -> SamplingParams:
         if self.default_sampling_params is None:
             self.default_sampling_params = (
@@ -1497,26 +1515,16 @@ def _validate_and_add_requests(
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
-        model_config = self.llm_engine.model_config
-
         for i, prompt in enumerate(it):
 
             if isinstance(prompt, dict):
                 self._validate_mm_data_and_uuids(
                     prompt.get("multi_modal_data"),
                     prompt.get("multi_modal_uuids"))
 
-            param = params[i] if isinstance(params, Sequence) else params
-
-            tokenization_kwargs: dict[str, Any] = {}
-            _validate_truncation_size(model_config.max_model_len,
-                                      param.truncate_prompt_tokens,
-                                      tokenization_kwargs)
-
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
-                tokenization_kwargs=tokenization_kwargs,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
                 priority=priority[i] if priority else 0,
@@ -1557,22 +1565,58 @@ def _validate_mm_data_and_uuids(
                     raise ValueError(f"Multi-modal data for {modality} is None"
                                      f" but UUID is not provided")
 
+    def _process_inputs(
+        self,
+        request_id: str,
+        engine_prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        *,
+        lora_request: Optional[LoRARequest],
+        priority: int,
+    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
+        """Use the Processor to process inputs for LLMEngine."""
+        tokenization_kwargs: dict[str, Any] = {}
+        _validate_truncation_size(self.model_config.max_model_len,
+                                  params.truncate_prompt_tokens,
+                                  tokenization_kwargs)
+
+        processor = self._get_processor()
+        engine_request = processor.process_inputs(
+            request_id,
+            engine_prompt,
+            params,
+            lora_request=lora_request,
+            tokenization_kwargs=tokenization_kwargs,
+            priority=priority,
+        )
+        return engine_request, tokenization_kwargs
+
     def _add_request(
         self,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
         lora_request: Optional[LoRARequest] = None,
         priority: int = 0,
     ) -> None:
+        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))
-        self.llm_engine.add_request(
+
+        engine_request, tokenization_kwargs = self._process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
+            priority=priority,
+        )
+
+        self.llm_engine.add_request(
+            request_id,
+            engine_request,
+            params,
+            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
+            prompt_text=prompt_text,
        )
 
    def _run_engine(

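The llm.py change is internal to the LLM entrypoint; the public generate/embed API is untouched. A minimal usage sketch of the new flow (assumes vLLM is installed; the model name is an arbitrary example, not part of this commit):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # any model; example assumption
    params = SamplingParams(max_tokens=16)

    # Unchanged public call. Internally, each prompt now goes through
    # LLM._process_inputs() -> Processor.process_inputs() (via the lazily
    # created Processor from _get_processor()), and the resulting
    # EngineCoreRequest is handed to LLMEngine.add_request(..., prompt_text=...).
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)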
vllm/entrypoints/openai/serving_engine.py

Lines changed: 10 additions & 31 deletions
@@ -7,8 +7,7 @@
 from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from http import HTTPStatus
-from typing import (Any, Callable, ClassVar, Generic, NamedTuple, Optional,
-                    TypeVar, Union)
+from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
 
 import torch
 from fastapi import Request
@@ -69,6 +68,7 @@
 # yapf: enable
 from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
+from vllm.inputs.parse import PromptComponents, get_prompt_components
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob, PromptLogprobs
 from vllm.lora.request import LoRARequest
@@ -140,12 +140,6 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
             and "prompt_embeds" in prompt)
 
 
-class PromptComponents(NamedTuple):
-    text: Optional[str] = None
-    token_ids: Optional[list[int]] = None
-    embeds: Optional[torch.Tensor] = None
-
-
 RequestT = TypeVar("RequestT", bound=AnyRequest)
 
 
@@ -876,25 +870,23 @@ async def _process_inputs(
         self,
         request_id: str,
         engine_prompt: PromptType,
-        sampling_params: SamplingParams,
+        params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest],
         trace_headers: Optional[Mapping[str, str]],
         priority: int,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
-        """
-        using the Processor to process inputs for AsyncLLM
-        """
+        """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.max_model_len,
-                                  sampling_params.truncate_prompt_tokens,
+                                  params.truncate_prompt_tokens,
                                   tokenization_kwargs)
 
         processor = await self._get_processor()
         engine_request = processor.process_inputs(
             request_id,
             engine_prompt,
-            sampling_params,
+            params,
             lora_request=lora_request,
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
@@ -973,25 +965,12 @@ async def _generate_with_builtin_tools(
 
     def _get_prompt_components(
         self,
-        inputs: Union[RequestPrompt, PromptType],
+        prompt: Union[RequestPrompt, PromptType],
     ) -> PromptComponents:
-        if isinstance(inputs, str):
-            return PromptComponents(text=inputs)
-        if isinstance(inputs, list):
-            return PromptComponents(token_ids=inputs)
-        if isinstance(inputs, dict):
-            return PromptComponents(
-                text=inputs.get("prompt"),  # type: ignore[arg-type]
-                token_ids=inputs.get(
-                    "prompt_token_ids"),  # type: ignore[arg-type]
-                embeds=inputs.get("prompt_embeds"),
-            )
+        if isinstance(prompt, list):
+            return PromptComponents(token_ids=prompt)
 
-        return PromptComponents(
-            text=getattr(inputs, "prompt", None),
-            token_ids=getattr(inputs, "prompt_token_ids", None),
-            embeds=getattr(inputs, "prompt_embeds", None),
-        )
+        return get_prompt_components(prompt)  # type: ignore[arg-type]
 
     def _log_inputs(
         self,

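In serving_engine.py, _process_inputs now accepts PoolingParams as well as SamplingParams, and the frontend-local PromptComponents logic is replaced by the shared helper from vllm.inputs.parse. A small sketch of the assumption behind the widened signature: both parameter classes carry truncate_prompt_tokens, which is the only field read here (passing it via the constructor is assumed, not shown in the diff):

    from vllm import PoolingParams, SamplingParams

    # Both parameter types expose the attribute that
    # _validate_truncation_size() consumes.
    for p in (SamplingParams(truncate_prompt_tokens=128),
              PoolingParams(truncate_prompt_tokens=128)):
        print(type(p).__name__, p.truncate_prompt_tokens)  # -> 128 for both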
vllm/inputs/parse.py

Lines changed: 25 additions & 1 deletion
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Sequence
-from typing import Literal, Optional, TypedDict, Union, cast, overload
+from typing import (TYPE_CHECKING, Literal, NamedTuple, Optional, TypedDict,
+                    Union, cast, overload)
 
 from typing_extensions import TypeIs
 
@@ -11,6 +12,9 @@
                        PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
                        TokensPrompt)
 
+if TYPE_CHECKING:
+    import torch
+
 
 class ParsedText(TypedDict):
     content: str
@@ -149,3 +153,23 @@ def split_enc_dec_inputs(
         )
 
     return None, inputs
+
+
+class PromptComponents(NamedTuple):
+    text: Optional[str] = None
+    token_ids: Optional[list[int]] = None
+    embeds: Optional["torch.Tensor"] = None
+
+
+def get_prompt_components(prompt: PromptType) -> PromptComponents:
+    if isinstance(prompt, str):
+        return PromptComponents(text=prompt)
+
+    if (encoder_prompt := prompt.get("encoder_prompt")):
+        return get_prompt_components(encoder_prompt)  # type: ignore[arg-type]
+
+    return PromptComponents(
+        text=prompt.get("prompt"),  # type: ignore[arg-type]
+        token_ids=prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
+        embeds=prompt.get("prompt_embeds"),
+    )

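Illustrative calls to the new shared helper, following the logic in the diff above (the results in the comments are what that logic implies, not captured output):

    from vllm.inputs.parse import get_prompt_components

    get_prompt_components("Hello")
    # -> PromptComponents(text="Hello", token_ids=None, embeds=None)

    get_prompt_components({"prompt_token_ids": [1, 2, 3]})
    # -> PromptComponents(text=None, token_ids=[1, 2, 3], embeds=None)

    # Encoder-decoder prompts recurse into the encoder part:
    get_prompt_components({
        "encoder_prompt": "Hello",
        "decoder_prompt": {"prompt_token_ids": [1]},
    })
    # -> PromptComponents(text="Hello", token_ids=None, embeds=None)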
vllm/v1/engine/llm_engine.py

Lines changed: 15 additions & 7 deletions
@@ -27,6 +27,7 @@
                                                init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -213,26 +214,33 @@ def abort_request(self, request_ids: list[str]) -> None:
     def add_request(
         self,
         request_id: str,
-        prompt: PromptType,
+        prompt: Union[EngineCoreRequest, PromptType],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         tokenization_kwargs: Optional[dict[str, Any]] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        prompt_text: Optional[str] = None,
     ) -> None:
         # Validate the request_id type.
         if not isinstance(request_id, str):
             raise TypeError(
                 f"request_id must be a string, got {type(request_id)}")
 
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                tokenization_kwargs,
-                                                trace_headers, priority)
-        prompt_text = prompt if isinstance(prompt,
-                                           str) else prompt.get("prompt")
+        if isinstance(prompt, EngineCoreRequest):
+            request = prompt
+        else:
+            assert prompt_text is None
+            logger.warning_once("Processor has been moved under LLM and will "
+                                "be removed from LLMEngine in v0.13.")
+            request = self.processor.process_inputs(request_id, prompt, params,
+                                                    arrival_time, lora_request,
+                                                    tokenization_kwargs,
+                                                    trace_headers, priority)
+            prompt_text = (prompt if isinstance(prompt, str) else
+                           prompt.get("prompt"))
 
         n = params.n if isinstance(params, SamplingParams) else 1

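With this change, LLMEngine.add_request accepts either a raw prompt (legacy path, which now logs a deprecation warning) or a pre-built EngineCoreRequest plus the original prompt text. A hedged sketch of both paths using the helpers added above (illustrative only; the model name and request ids are arbitrary, and the private attributes are used purely to demonstrate the flow):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")   # example model
    engine = llm.llm_engine                # v1 LLMEngine
    processor = llm._get_processor()       # lazily built Processor (new in this commit)
    params = SamplingParams(max_tokens=8)

    # New path: process inputs outside the engine, then hand over the result.
    engine_request = processor.process_inputs("req-0", "Hello, my name is", params)
    engine.add_request("req-0", engine_request, params,
                       prompt_text="Hello, my name is")

    # Legacy path: pass the raw prompt; the engine's own Processor still runs,
    # with a warning that it will be removed from LLMEngine in v0.13.
    engine.add_request("req-1", "Hello, my name is", params)

    # Outputs are then drained as usual, e.g. via llm._run_engine(use_tqdm=False).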