22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33from collections .abc import Mapping
44from dataclasses import dataclass
5- from typing import TYPE_CHECKING , Generic , Protocol , TypeVar
6-
7- import torch .nn as nn
5+ from typing import TYPE_CHECKING , Generic , Protocol , TypeVar , cast
86
97from vllm .config .multimodal import BaseDummyOptions
108from vllm .logger import init_logger
119from vllm .transformers_utils .tokenizer import AnyTokenizer , cached_tokenizer_from_config
12- from vllm .utils .collection_utils import ClassRegistry
1310
1411from .cache import BaseMultiModalProcessorCache
1512from .processing import (
2623
2724if TYPE_CHECKING :
2825 from vllm .config import ModelConfig
26+ from vllm .model_executor .models .interfaces import SupportsMultiModal
2927
3028logger = init_logger (__name__ )
3129
32- N = TypeVar ("N" , bound = type [nn . Module ])
30+ N = TypeVar ("N" , bound = type ["SupportsMultiModal" ])
3331_I = TypeVar ("_I" , bound = BaseProcessingInfo )
3432_I_co = TypeVar ("_I_co" , bound = BaseProcessingInfo , covariant = True )
3533
@@ -95,9 +93,6 @@ class MultiModalRegistry:
9593 A registry that dispatches data processing according to the model.
9694 """
9795
98- def __init__ (self ) -> None :
99- self ._processor_factories = ClassRegistry [nn .Module , _ProcessorFactories ]()
100-
10196 def _extract_mm_options (
10297 self ,
10398 model_config : "ModelConfig" ,
@@ -207,15 +202,15 @@ def register_processor(
207202 """
208203
209204 def wrapper (model_cls : N ) -> N :
210- if self . _processor_factories . contains ( model_cls , strict = True ) :
205+ if "_processor_factory" in model_cls . __dict__ :
211206 logger .warning (
212207 "Model class %s already has a multi-modal processor "
213208 "registered to %s. It is overwritten by the new one." ,
214209 model_cls ,
215210 self ,
216211 )
217212
218- self . _processor_factories [ model_cls ] = _ProcessorFactories (
213+ model_cls . _processor_factory = _ProcessorFactories (
219214 info = info ,
220215 dummy_inputs = dummy_inputs ,
221216 processor = processor ,
@@ -225,12 +220,13 @@ def wrapper(model_cls: N) -> N:
225220
226221 return wrapper
227222
228- def _get_model_cls (self , model_config : "ModelConfig" ):
223+ def _get_model_cls (self , model_config : "ModelConfig" ) -> "SupportsMultiModal" :
229224 # Avoid circular import
230225 from vllm .model_executor .model_loader import get_model_architecture
231226
232227 model_cls , _ = get_model_architecture (model_config )
233- return model_cls
228+ assert hasattr (model_cls , "_processor_factory" )
229+ return cast ("SupportsMultiModal" , model_cls )
234230
235231 def _create_processing_ctx (
236232 self ,
@@ -248,7 +244,7 @@ def _create_processing_info(
248244 tokenizer : AnyTokenizer | None = None ,
249245 ) -> BaseProcessingInfo :
250246 model_cls = self ._get_model_cls (model_config )
251- factories = self . _processor_factories [ model_cls ]
247+ factories = model_cls . _processor_factory
252248 ctx = self ._create_processing_ctx (model_config , tokenizer )
253249 return factories .info (ctx )
254250
@@ -266,7 +262,7 @@ def create_processor(
266262 raise ValueError (f"{ model_config .model } is not a multimodal model" )
267263
268264 model_cls = self ._get_model_cls (model_config )
269- factories = self . _processor_factories [ model_cls ]
265+ factories = model_cls . _processor_factory
270266
271267 ctx = self ._create_processing_ctx (model_config , tokenizer )
272268
0 commit comments