Skip to content

Commit 7675ba3

Browse files
[Misc] Remove redundant ClassRegistry (#29681)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 7c1ed45 commit 7675ba3

File tree

5 files changed

+25
-49
lines changed

5 files changed

+25
-49
lines changed

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def _test_processing_correctness(
233233
)
234234

235235
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
236-
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
236+
factories = model_cls._processor_factory
237237
ctx = InputProcessingContext(
238238
model_config,
239239
tokenizer=cached_tokenizer_from_config(model_config),

tests/models/multimodal/processing/test_tensor_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def test_model_tensor_schema(model_id: str):
193193
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
194194
assert supports_multimodal(model_cls)
195195

196-
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
196+
factories = model_cls._processor_factory
197197

198198
inputs_parse_methods = []
199199
for attr_name in dir(model_cls):

vllm/model_executor/models/interfaces.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,13 @@
3232
from vllm.config import VllmConfig
3333
from vllm.model_executor.models.utils import WeightsMapper
3434
from vllm.multimodal.inputs import MultiModalFeatureSpec
35+
from vllm.multimodal.registry import _ProcessorFactories
3536
from vllm.sequence import IntermediateTensors
3637
else:
3738
VllmConfig = object
3839
WeightsMapper = object
3940
MultiModalFeatureSpec = object
41+
_ProcessorFactories = object
4042
IntermediateTensors = object
4143

4244
logger = init_logger(__name__)
@@ -87,6 +89,11 @@ class SupportsMultiModal(Protocol):
8789
A set indicating CPU-only multimodal fields.
8890
"""
8991

92+
_processor_factory: ClassVar[_ProcessorFactories]
93+
"""
94+
Set internally by `MultiModalRegistry.register_processor`.
95+
"""
96+
9097
@classmethod
9198
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
9299
"""

vllm/multimodal/registry.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from collections.abc import Mapping
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
6-
7-
import torch.nn as nn
5+
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
86

97
from vllm.config.multimodal import BaseDummyOptions
108
from vllm.logger import init_logger
119
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
12-
from vllm.utils.collection_utils import ClassRegistry
1310

1411
from .cache import BaseMultiModalProcessorCache
1512
from .processing import (
@@ -26,10 +23,11 @@
2623

2724
if TYPE_CHECKING:
2825
from vllm.config import ModelConfig
26+
from vllm.model_executor.models.interfaces import SupportsMultiModal
2927

3028
logger = init_logger(__name__)
3129

32-
N = TypeVar("N", bound=type[nn.Module])
30+
N = TypeVar("N", bound=type["SupportsMultiModal"])
3331
_I = TypeVar("_I", bound=BaseProcessingInfo)
3432
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
3533

@@ -95,9 +93,6 @@ class MultiModalRegistry:
9593
A registry that dispatches data processing according to the model.
9694
"""
9795

98-
def __init__(self) -> None:
99-
self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]()
100-
10196
def _extract_mm_options(
10297
self,
10398
model_config: "ModelConfig",
@@ -207,15 +202,15 @@ def register_processor(
207202
"""
208203

209204
def wrapper(model_cls: N) -> N:
210-
if self._processor_factories.contains(model_cls, strict=True):
205+
if "_processor_factory" in model_cls.__dict__:
211206
logger.warning(
212207
"Model class %s already has a multi-modal processor "
213208
"registered to %s. It is overwritten by the new one.",
214209
model_cls,
215210
self,
216211
)
217212

218-
self._processor_factories[model_cls] = _ProcessorFactories(
213+
model_cls._processor_factory = _ProcessorFactories(
219214
info=info,
220215
dummy_inputs=dummy_inputs,
221216
processor=processor,
@@ -225,12 +220,13 @@ def wrapper(model_cls: N) -> N:
225220

226221
return wrapper
227222

228-
def _get_model_cls(self, model_config: "ModelConfig"):
223+
def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
229224
# Avoid circular import
230225
from vllm.model_executor.model_loader import get_model_architecture
231226

232227
model_cls, _ = get_model_architecture(model_config)
233-
return model_cls
228+
assert hasattr(model_cls, "_processor_factory")
229+
return cast("SupportsMultiModal", model_cls)
234230

235231
def _create_processing_ctx(
236232
self,
@@ -248,7 +244,7 @@ def _create_processing_info(
248244
tokenizer: AnyTokenizer | None = None,
249245
) -> BaseProcessingInfo:
250246
model_cls = self._get_model_cls(model_config)
251-
factories = self._processor_factories[model_cls]
247+
factories = model_cls._processor_factory
252248
ctx = self._create_processing_ctx(model_config, tokenizer)
253249
return factories.info(ctx)
254250

@@ -266,7 +262,7 @@ def create_processor(
266262
raise ValueError(f"{model_config.model} is not a multimodal model")
267263

268264
model_cls = self._get_model_cls(model_config)
269-
factories = self._processor_factories[model_cls]
265+
factories = model_cls._processor_factory
270266

271267
ctx = self._create_processing_ctx(model_config, tokenizer)
272268

vllm/utils/collection_utils.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,64 +6,37 @@
66
This is similar in concept to the `collections` module.
77
"""
88

9-
from collections import UserDict, defaultdict
9+
from collections import defaultdict
1010
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
1111
from typing import Generic, Literal, TypeVar
1212

1313
from typing_extensions import TypeIs, assert_never
1414

1515
T = TypeVar("T")
16-
U = TypeVar("U")
1716

1817
_K = TypeVar("_K", bound=Hashable)
1918
_V = TypeVar("_V")
2019

2120

22-
class ClassRegistry(UserDict[type[T], _V]):
23-
"""
24-
A registry that acts like a dictionary but searches for other classes
25-
in the MRO if the original class is not found.
26-
"""
27-
28-
def __getitem__(self, key: type[T]) -> _V:
29-
for cls in key.mro():
30-
if cls in self.data:
31-
return self.data[cls]
32-
33-
raise KeyError(key)
34-
35-
def __contains__(self, key: object) -> bool:
36-
return self.contains(key)
37-
38-
def contains(self, key: object, *, strict: bool = False) -> bool:
39-
if not isinstance(key, type):
40-
return False
41-
42-
if strict:
43-
return key in self.data
44-
45-
return any(cls in self.data for cls in key.mro())
46-
47-
48-
class LazyDict(Mapping[str, T], Generic[T]):
21+
class LazyDict(Mapping[str, _V], Generic[_V]):
4922
"""
5023
Evaluates dictionary items only when they are accessed.
5124
5225
Adapted from: https://stackoverflow.com/a/47212782/5082708
5326
"""
5427

55-
def __init__(self, factory: dict[str, Callable[[], T]]):
28+
def __init__(self, factory: dict[str, Callable[[], _V]]):
5629
self._factory = factory
57-
self._dict: dict[str, T] = {}
30+
self._dict: dict[str, _V] = {}
5831

59-
def __getitem__(self, key: str) -> T:
32+
def __getitem__(self, key: str) -> _V:
6033
if key not in self._dict:
6134
if key not in self._factory:
6235
raise KeyError(key)
6336
self._dict[key] = self._factory[key]()
6437
return self._dict[key]
6538

66-
def __setitem__(self, key: str, value: Callable[[], T]):
39+
def __setitem__(self, key: str, value: Callable[[], _V]):
6740
self._factory[key] = value
6841

6942
def __iter__(self):

0 commit comments

Comments
 (0)