1616from vllm .transformers_utils .tokenizer import AnyTokenizer
1717
1818from .data import (DecoderOnlyInputs , EmbedsInputs , EmbedsPrompt ,
19- EncoderDecoderInputs , ProcessorInputs , PromptType ,
20- SingletonInputs , SingletonPrompt , TextPrompt , TokenInputs ,
21- TokensPrompt , embeds_inputs , token_inputs )
19+ EncoderDecoderInputs , ExplicitEncoderDecoderPrompt ,
20+ ProcessorInputs , PromptType , SingletonInputs ,
21+ SingletonPrompt , TextPrompt , TokenInputs , TokensPrompt ,
22+ embeds_inputs , token_inputs )
2223from .parse import is_explicit_encoder_decoder_prompt , parse_singleton_prompt
2324
2425logger = init_logger (__name__ )
@@ -322,7 +323,7 @@ def _process_tokens(
322323 mm_uuids = mm_uuids ,
323324 )
324325 else :
325- inputs = token_inputs (prompt_token_ids = prompt_token_ids )
326+ inputs = token_inputs (prompt_token_ids )
326327
327328 if cache_salt := parsed_content .get ("cache_salt" ):
328329 inputs ["cache_salt" ] = cache_salt
@@ -352,10 +353,7 @@ def _process_text(
352353 prompt_text ,
353354 tokenization_kwargs = tokenization_kwargs ,
354355 )
355- inputs = token_inputs (
356- prompt = prompt_text ,
357- prompt_token_ids = prompt_token_ids ,
358- )
356+ inputs = token_inputs (prompt_token_ids )
359357
360358 if cache_salt := parsed_content .get ("cache_salt" ):
361359 inputs ["cache_salt" ] = cache_salt
@@ -473,22 +471,17 @@ def _split_enc_dec_mm_inputs(
473471 decoder_inputs : SingletonInputs
474472
475473 if inputs ["type" ] == "multimodal" : # Multimodal data inputs
476- if not ("encoder_prompt" in inputs
477- and "encoder_prompt_token_ids" in inputs ):
474+ if "encoder_prompt_token_ids" not in inputs :
478475 raise RuntimeError ("You should register an encoder-decoder "
479476 "multi-modal processor for encoder-decoder "
480477 "models." )
481478 inputs = cast (MultiModalEncDecInputs , inputs )
482479
483- encoder_inputs = token_inputs (
484- prompt = inputs ["encoder_prompt" ],
485- prompt_token_ids = inputs ["encoder_prompt_token_ids" ],
486- )
480+ encoder_inputs = token_inputs (inputs ["encoder_prompt_token_ids" ])
487481
488482 decoder_prompt_inputs = decoder_inputs_to_override or inputs
489483 decoder_inputs = MultiModalInputs (
490484 type = "multimodal" ,
491- prompt = decoder_prompt_inputs .get ("prompt" , "" ),
492485 prompt_token_ids = decoder_prompt_inputs ["prompt_token_ids" ],
493486 mm_kwargs = inputs ["mm_kwargs" ],
494487 mm_hashes = inputs ["mm_hashes" ],
@@ -498,7 +491,7 @@ def _split_enc_dec_mm_inputs(
498491 decoder_inputs ["cache_salt" ] = cache_salt
499492
500493 elif inputs ["type" ] == "token" : # Text-only inputs
501- encoder_inputs = token_inputs (prompt = "" , prompt_token_ids = [])
494+ encoder_inputs = token_inputs (prompt_token_ids = [])
502495 decoder_inputs = decoder_inputs_to_override or inputs
503496 else :
504497 assert_never (inputs ) # type: ignore[arg-type]
@@ -549,12 +542,14 @@ def _process_encoder_decoder_prompt(
549542 decoder_inputs : Optional [SingletonInputs ]
550543
551544 if is_explicit_encoder_decoder_prompt (prompt ):
545+ # `cast` is needed for mypy, but not pyright
546+ prompt_ = cast (ExplicitEncoderDecoderPrompt , prompt )
552547 encoder_inputs = self ._prompt_to_llm_inputs (
553- prompt ["encoder_prompt" ],
548+ prompt_ ["encoder_prompt" ],
554549 tokenization_kwargs = tokenization_kwargs ,
555550 mm_uuids = mm_uuids ,
556551 )
557- if (decoder_input := prompt ["decoder_prompt" ]) is None :
552+ if (decoder_input := prompt_ ["decoder_prompt" ]) is None :
558553 decoder_inputs = None
559554 else :
560555 decoder_inputs = self ._prompt_to_llm_inputs (decoder_input )
@@ -565,8 +560,9 @@ def _process_encoder_decoder_prompt(
565560 self ._split_enc_dec_mm_inputs (encoder_inputs ,
566561 decoder_inputs ))
567562 else :
563+ # `cast` is needed for mypy, but not pyright
568564 inputs = self ._prompt_to_llm_inputs (
569- prompt ,
565+ cast ( SingletonPrompt , prompt ) ,
570566 tokenization_kwargs = tokenization_kwargs ,
571567 mm_uuids = mm_uuids ,
572568 )
@@ -641,8 +637,9 @@ def preprocess(
641637 "to decoder-only models" )
642638
643639 # Decoder-only operation
640+ # `cast` is needed for mypy, but not pyright
644641 return self ._process_decoder_only_prompt (
645- prompt ,
642+ cast ( SingletonPrompt , prompt ) ,
646643 tokenization_kwargs = tokenization_kwargs ,
647644 mm_uuids = mm_uuids ,
648645 )
0 commit comments