Commit 12c1287
[mypy] Further improve MM type annotations (vllm-project#25654)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 17b4c66 commit 12c1287

File tree

vllm/model_executor/models/transformers.py
vllm/multimodal/inputs.py
vllm/multimodal/processing.py
vllm/multimodal/profiling.py
vllm/multimodal/utils.py
vllm/utils/jsontree.py

6 files changed: 90 additions, 48 deletions

vllm/model_executor/models/transformers.py

Lines changed: 5 additions & 2 deletions
@@ -415,9 +415,12 @@ def apply(
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
                                        num_image_patches),
         )
+
         # Use overrides if provided; fallback to data-dependent hashing.
-        mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
-            mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
+        mm_hashes = self._hash_mm_items(mm_items,
+                                        hf_processor_mm_kwargs,
+                                        tokenization_kwargs,
+                                        mm_uuids=mm_uuids)
 
         return MultiModalInputs(
             type="multimodal",

vllm/multimodal/inputs.py

Lines changed: 15 additions & 8 deletions
@@ -14,7 +14,7 @@
 from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
 
 from vllm.utils import LazyLoader, full_groupby, is_list_of
-from vllm.utils.jsontree import JSONTree, json_map_leaves
+from vllm.utils.jsontree import json_map_leaves
 
 if TYPE_CHECKING:
     import torch

@@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     return a == b
 
 
-BatchedTensorInputs: TypeAlias = Mapping[str, NestedTensors]
+BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]
 """
 A dictionary containing nested tensors which have been batched via
 [`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].

@@ -377,6 +377,7 @@ def _reduce_data(
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.stack(batch)`

@@ -422,6 +423,7 @@ def _reduce_data(
         pin_memory: bool,
     ) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            batch = cast(list[torch.Tensor], batch)
             if len(batch) == 1:
                 # An optimization when `batch` contains only one tensor:
                 # - produce exactly same result as `torch.concat(batch)`

@@ -764,6 +766,15 @@ def __getitem__(self, modality: str) -> Sequence[_I]:
 
         return super().__getitem__(modality)  # type: ignore[return-value]
 
+    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
+        for modality, items in self.items():
+            for i, item in enumerate(items):
+                if item is None:
+                    raise RuntimeError(
+                        f"Found empty mm_items[{modality}][{i}]")
+
+        return self  # type: ignore[return-value]
+
     def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for modality, items in self.items():

@@ -897,15 +908,11 @@ def as_kwargs(
         *,
         device: torch.types.Device,
     ) -> BatchedTensorInputs:
-        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
-
-        json_mapped = json_map_leaves(
+        return json_map_leaves(
             lambda x: x.to(device=device, non_blocking=True),
-            json_inputs,
+            batched_inputs,
         )
 
-        return cast(BatchedTensorInputs, json_mapped)
-
     def __getitem__(self, key: str):
         if key not in self:
             raise KeyError(f"Keyword argument {key!r} not found. "
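
The `cast(list[torch.Tensor], batch)` lines added here follow a common mypy pattern: `is_list_of` proves the element type at runtime but does not narrow the static type in this code path, so an explicit `cast` records what the check already guarantees. A toy, runnable sketch of the pattern, with `int` standing in for `torch.Tensor` and a re-implemented `is_list_of`:

from typing import cast


def is_list_of(value: object, typ: type) -> bool:
    """Stand-in for vllm.utils.is_list_of: runtime element-type check."""
    return isinstance(value, list) and all(isinstance(x, typ) for x in value)


def reduce_data(batch: list[object]) -> object:
    if len(batch) > 0 and is_list_of(batch, int):
        # The runtime check above guarantees this; the cast informs mypy.
        ints = cast(list[int], batch)
        if len(ints) == 1:
            return ints[0]  # single-element fast path, as in _reduce_data
        return sum(ints)  # stand-in for torch.stack / torch.concat
    return batch


print(reduce_data([1, 2, 3]))
print(reduce_data(["a", 2]))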

vllm/multimodal/processing.py

Lines changed: 9 additions & 10 deletions
@@ -1585,7 +1585,7 @@ def _hash_mm_items(
         *,
         mm_uuids: Optional[MultiModalUUIDDict] = None,
     ) -> MultiModalHashes:
-        """Create MM hashes to be returned (only used in V1).
+        """Create MM hashes to be returned.
 
         Note: When overrides are provided via callers of `apply`,

@@ -2098,23 +2098,22 @@ def _get_enc_dec_inputs(
         encoder_inputs: MultiModalInputs,
     ):
         tokenizer = self.info.get_tokenizer()
-        decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
-        if isinstance(decoder_prompt, str):
+        decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
+        if isinstance(decoder_prompt_raw, str):
+            decoder_prompt = decoder_prompt_raw
             decoder_prompt_ids = encode_tokens(tokenizer,
-                                               decoder_prompt,
+                                               decoder_prompt_raw,
                                                add_special_tokens=False)
         else:
-            decoder_prompt_ids = decoder_prompt
-            decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
+            decoder_prompt = decode_tokens(tokenizer, decoder_prompt_raw)
+            decoder_prompt_ids = decoder_prompt_raw
 
         mm_inputs = MultiModalEncDecInputs(
             encoder_prompt=encoder_inputs["prompt"],
             encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
             **encoder_inputs)
-        mm_inputs.update({
-            "prompt": decoder_prompt,
-            "prompt_token_ids": decoder_prompt_ids
-        })
+        mm_inputs["prompt"] = decoder_prompt
+        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
         return mm_inputs
 
     def apply(
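
The `decoder_prompt_raw` rename is a narrowing pattern: keep the `Union[str, list[int]]` value under one name and bind each branch's view to its own consistently typed name, so mypy can verify every variable. A self-contained sketch with trivial stand-ins for `encode_tokens`/`decode_tokens`:

from typing import Union


def split_decoder_prompt(
    decoder_prompt_raw: Union[str, list[int]],
) -> tuple[str, list[int]]:
    if isinstance(decoder_prompt_raw, str):
        decoder_prompt = decoder_prompt_raw
        # Stand-in for encode_tokens(tokenizer, decoder_prompt_raw, ...).
        decoder_prompt_ids = [ord(c) for c in decoder_prompt_raw]
    else:
        # Stand-in for decode_tokens(tokenizer, decoder_prompt_raw).
        decoder_prompt = "".join(chr(t) for t in decoder_prompt_raw)
        decoder_prompt_ids = decoder_prompt_raw
    return decoder_prompt, decoder_prompt_ids


print(split_decoder_prompt("hi"))
print(split_decoder_prompt([104, 105]))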

vllm/multimodal/profiling.py

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@
 from vllm.logger import init_logger
 
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
-                     MultiModalInputs, MultiModalKwargsOptionalItems,
+                     MultiModalInputs, MultiModalKwargsItems,
                      MultiModalPlaceholderDict)
 from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                          EncDecMultiModalProcessor)

@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
     """Dummy data used for profiling."""
 
     prompt_token_ids: list[int]
-    multi_modal_data: MultiModalKwargsOptionalItems
+    multi_modal_data: MultiModalKwargsItems
     multi_modal_placeholders: MultiModalPlaceholderDict
 
 
@@ -239,7 +239,7 @@ def get_decoder_dummy_data(
 
     return DummyDecoderData(
         prompt_token_ids=prompt_token_ids,
-        multi_modal_data=mm_inputs["mm_kwargs"],
+        multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
         multi_modal_placeholders=mm_inputs["mm_placeholders"],
     )
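
A hedged sketch of the call-site effect: once the NamedTuple field drops `Optional`, construction only type-checks through the runtime-narrowing helper, which also fails fast on missing data. All types below are simplified stand-ins for the vLLM classes:

from typing import NamedTuple, Optional, cast


class DummyDecoderDataSketch(NamedTuple):
    prompt_token_ids: list[int]
    multi_modal_data: dict[str, list[str]]  # non-optional entries only


def require_data(
        items: dict[str, list[Optional[str]]]) -> dict[str, list[str]]:
    # Runtime guarantee that lets us hand mypy the narrowed type.
    for modality, entries in items.items():
        for i, entry in enumerate(entries):
            if entry is None:
                raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")
    return cast(dict[str, list[str]], items)


mm_kwargs: dict[str, list[Optional[str]]] = {"image": ["item-0"]}
dummy = DummyDecoderDataSketch(
    prompt_token_ids=[1, 2, 3],
    multi_modal_data=require_data(mm_kwargs),  # raises on None entries
)
print(dummy.multi_modal_data)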

vllm/multimodal/utils.py

Lines changed: 24 additions & 20 deletions
@@ -19,6 +19,7 @@
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
+from vllm.utils.jsontree import json_map_leaves
 
 from .audio import AudioMediaIO
 from .base import MediaIO

@@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
+    merge_by_field_config: bool = False,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.

@@ -400,29 +402,31 @@ def group_mm_kwargs_by_modality(
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)
 
-        # mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
-        #     .get_data(pin_memory=pin_memory)
-
-        # if device is not None:
-        #     mm_kwargs_group = json_map_leaves(
-        #         lambda x: x.to(device=device),
-        #         mm_kwargs_group,
-        #     )
-
-        # TODO: Once V0 is removed, we can use the merging logic above
+        # TODO: Enable `merge_by_field_config` for all models
         # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
         # We will also need to update each model to remove `flatten_bn`.
-        mm_kwargs_group = MultiModalKwargs.as_kwargs(
-            MultiModalKwargs.batch(
-                [
-                    MultiModalKwargsItems.from_seq([item]).get_data()
-                    for item in items_lst
-                ],
-                pin_memory=pin_memory,
-            ),
-            device=device,
-        )
+        if merge_by_field_config:
+            mm_kwargs_group: BatchedTensorInputs = dict(
+                MultiModalKwargsItems.from_seq(items_lst).get_data(
+                    pin_memory=pin_memory))
+
+            if device is not None:
+                mm_kwargs_group = json_map_leaves(
+                    lambda x: x.to(device=device),
+                    mm_kwargs_group,
+                )
+        else:
+            mm_kwargs_group = MultiModalKwargs.as_kwargs(
+                MultiModalKwargs.batch(
+                    [
+                        MultiModalKwargsItems.from_seq([item]).get_data()
+                        for item in items_lst
+                    ],
+                    pin_memory=pin_memory,
+                ),
+                device=device,
+            )
 
         yield modality, len(items_lst), mm_kwargs_group
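
A self-contained sketch (toy data types, no torch) of the new control flow: consecutive items of the same modality are grouped with `itertools.groupby`, and the `merge_by_field_config` flag selects between the merged path and the legacy per-item batching path:

from collections.abc import Iterable
from dataclasses import dataclass
from itertools import groupby


@dataclass
class Item:
    modality: str
    data: list[float]


def group_by_modality(
    mm_kwargs: list[Item],
    *,
    merge_by_field_config: bool = False,
) -> Iterable[tuple[str, int, list]]:
    # Consecutive items of the same modality form one group.
    for modality, items in groupby(mm_kwargs, key=lambda it: it.modality):
        items_lst = list(items)
        group: list
        if merge_by_field_config:
            # Merged path: one flat group, no extra batch dimension.
            group = [x for it in items_lst for x in it.data]
        else:
            # Legacy path: batch per item, keeping the extra dimension.
            group = [it.data for it in items_lst]
        yield modality, len(items_lst), group


items = [Item("image", [1.0]), Item("image", [2.0]), Item("audio", [3.0])]
for modality, n, group in group_by_modality(items, merge_by_field_config=True):
    print(modality, n, group)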

vllm/utils/jsontree.py

Lines changed: 34 additions & 5 deletions
@@ -4,7 +4,12 @@
 
 from collections.abc import Iterable
 from functools import reduce
-from typing import Callable, TypeVar, Union, cast, overload
+from typing import TYPE_CHECKING, Callable, TypeVar, Union, cast, overload
+
+if TYPE_CHECKING:
+    import torch
+
+    from vllm.multimodal.inputs import BatchedTensorInputs
 
 _T = TypeVar("_T")
 _U = TypeVar("_U")

@@ -17,6 +22,19 @@
 ]
 """A nested JSON structure where the leaves need not be JSON-serializable."""
 
+_JSONTree = Union[
+    dict[str, "JSONTree[_T]"],
+    list["JSONTree[_T]"],
+    tuple["JSONTree[_T]", ...],
+    dict[str, _T],
+    list[_T],
+    tuple[_T, ...],
+    _T,
+]
+"""
+Same as `JSONTree` but with additional `Union` members to satisfy overloads.
+"""
+
 
 def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
     """Iterate through each leaf in a nested JSON structure."""

@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
             yield value
 
 
+@overload
+def json_map_leaves(
+    func: Callable[["torch.Tensor"], "torch.Tensor"],
+    value: "BatchedTensorInputs",
+) -> "BatchedTensorInputs":
+    ...
+
+
 @overload
 def json_map_leaves(
     func: Callable[[_T], _U],

@@ -64,11 +90,14 @@ def json_map_leaves(
 
 def json_map_leaves(
     func: Callable[[_T], _U],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
-) -> Union[dict[str, _U], list[_U], tuple[_U, ...], JSONTree[_U]]:
+    value: Union["BatchedTensorInputs", _JSONTree[_T]],
+) -> Union["BatchedTensorInputs", _JSONTree[_U]]:
     """Apply a function to each leaf in a nested JSON structure."""
     if isinstance(value, dict):
-        return {k: json_map_leaves(func, v) for k, v in value.items()}
+        return {
+            k: json_map_leaves(func, v)  # type: ignore[arg-type]
+            for k, v in value.items()
+        }
     elif isinstance(value, list):
         return [json_map_leaves(func, v) for v in value]
     elif isinstance(value, tuple):

@@ -125,7 +154,7 @@ def json_reduce_leaves(
 
 def json_reduce_leaves(
     func: Callable[..., Union[_T, _U]],
-    value: Union[dict[str, _T], list[_T], tuple[_T, ...], JSONTree[_T]],
+    value: _JSONTree[_T],
     initial: _U = cast(_U, ...),  # noqa: B008
     /,
 ) -> Union[_T, _U]:
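
A reduced sketch of the overload technique applied to `json_map_leaves`: alongside the generic tree overloads, one extra overload maps a concrete dict alias (here `StrDict`, playing the role of `BatchedTensorInputs`) to itself, so callers holding that alias get it back without a `cast`. Names are illustrative, not the vLLM API:

from typing import Callable, Union, overload

Tree = Union[dict[str, "Tree"], list["Tree"], str]
StrDict = dict[str, str]  # plays the role of BatchedTensorInputs


@overload
def map_leaves(func: Callable[[str], str], value: StrDict) -> StrDict: ...


@overload
def map_leaves(func: Callable[[str], str], value: Tree) -> Tree: ...


def map_leaves(func, value):
    # Recurse through dicts and lists; apply func at the string leaves.
    if isinstance(value, dict):
        return {k: map_leaves(func, v) for k, v in value.items()}
    if isinstance(value, list):
        return [map_leaves(func, v) for v in value]
    return func(value)


batched: StrDict = {"pixel_values": "t1"}
moved: StrDict = map_leaves(str.upper, batched)  # no cast() needed
print(moved)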
