
Commit cb15ee2

Allow Gemma3 to take image embeddings (#28483)
Signed-off-by: tingtinggithub <streamttt@gmail.com>
1 parent: f36292d

4 files changed, +69 −29 lines


docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I<sup>+</sup> | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>E+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
 | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
 | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |

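The `E` superscript added here marks that precomputed image embeddings can now be passed to Gemma 3 in place of raw images. A minimal usage sketch of what that enables; the prompt placeholder, token count, and hidden size are illustrative assumptions, not values taken from this commit:

import torch
from vllm import LLM

llm = LLM(model="google/gemma-3-4b-it")

# Precomputed embeddings: one row per image, already projected into the
# language model's hidden space. The (1, 256, 2560) shape is hypothetical.
image_embeds = torch.randn(1, 256, 2560)

outputs = llm.generate({
    # "<start_of_image>" is assumed here to be Gemma 3's image placeholder.
    "prompt": "<start_of_image> Describe this image.",
    "multi_modal_data": {"image": image_embeds},
})
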
vllm/model_executor/models/gemma3_mm.py

Lines changed: 55 additions & 22 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
 import torch
 from torch import nn
@@ -20,7 +20,12 @@
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (
+    ImageEmbeddingItems,
+    ImageProcessorItems,
+    ImageSize,
+    MultiModalDataItems,
+)
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-Gemma3ImageInputs = Gemma3ImagePixelInputs
+class Gemma3ImageEmbeddingInputs(TensorSchema):
+    type: Literal["image_embeds"] = "image_embeds"
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("ni", "nf", "hs"),
+    ]
+
+
+Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
 
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
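The `TensorShape("ni", "nf", "hs")` annotation above constrains embedding inputs to three dimensions: number of images, features (tokens) per image, and hidden size. A quick sketch of a tensor that would satisfy the schema; the concrete sizes are assumptions for illustration:

import torch

num_images = 2          # "ni": images in the request
tokens_per_image = 256  # "nf": assumed fixed token count per image
hidden_size = 2560      # "hs": assumed language-model hidden size

image_embeds = torch.randn(num_images, tokens_per_image, hidden_size)
assert image_embeds.ndim == 3  # the rank the schema enforces
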
@@ -178,20 +191,23 @@ def get_num_crops(
     def get_image_repl(
         self,
         *,
-        image_width: int,
-        image_height: int,
+        image_width: int | None,
+        image_height: int | None,
+        num_crops: int | None = None,
         processor: Gemma3Processor | None,
     ) -> PromptUpdateDetails[str]:
         if processor is None:
             processor = self.get_hf_processor()
 
         boi_token = processor.boi_token
 
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
+        if num_crops is None:
+            assert image_width is not None and image_height is not None
+            num_crops = self.get_num_crops(
+                image_width=image_width,
+                image_height=image_height,
+                processor=processor,
+            )
 
         if num_crops == 0:
             image_text = boi_token
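The widened signature admits two call patterns: the pixel path still derives `num_crops` from the image size, while callers that already know the crop count (the embedding path further down) pass it explicitly. A runnable sketch of the branching, with a stub standing in for the real method:

# Stub mirroring the control flow of get_image_repl above; the real method
# computes num_crops via get_num_crops and builds a prompt replacement.
def get_image_repl_stub(*, image_width, image_height, num_crops=None):
    if num_crops is None:
        assert image_width is not None and image_height is not None
        num_crops = 0  # stand-in for the real size-based computation
    return f"<replacement with {num_crops} crops>"

# Pixel inputs: sizes known, num_crops derived internally.
print(get_image_repl_stub(image_width=896, image_height=896))

# Embedding inputs: no pixel size available, crops forced to zero.
print(get_image_repl_stub(image_width=None, image_height=None, num_crops=0))
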
@@ -321,6 +337,7 @@ def _get_mm_fields_config(
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
             num_patches=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -333,7 +350,19 @@ def _get_prompt_updates(
         image_token = hf_processor.boi_token
 
         def get_replacement_gemma3(item_idx: int):
-            images = mm_items.get_items("image", ImageProcessorItems)
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems)
+            )
+
+            if isinstance(images, ImageEmbeddingItems):
+                # Image embedding inputs only support the no-crops case,
+                # since crops are not supported by the HF processor anyway.
+                return self.info.get_image_repl(
+                    image_width=None,
+                    image_height=None,
+                    num_crops=0,
+                    processor=hf_processor,
+                )
 
             image_size = images.get_image_size(item_idx)
             return self.info.get_image_repl(
@@ -557,17 +586,19 @@ def _parse_and_validate_image_input(
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
-        assert image_embeds is None, "Gemma3 does not support image_embeds."
-        if pixel_values is None:
-            return None
 
-        image_size = self.config.vision_config.image_size
-
-        return Gemma3ImagePixelInputs(
-            pixel_values=pixel_values,
-            num_patches=num_patches,
-            resolve_bindings={"h": image_size, "w": image_size},
-        )
+        if pixel_values is not None:
+            image_size = self.config.vision_config.image_size
+            return Gemma3ImagePixelInputs(
+                pixel_values=pixel_values,
+                num_patches=num_patches,
+                resolve_bindings={"h": image_size, "w": image_size},
+            )
+        elif image_embeds is not None:
+            return Gemma3ImageEmbeddingInputs(
+                image_embeds=image_embeds,
+                type="image_embeds",
+            )
 
     def _image_pixels_to_features(
         self,
@@ -579,7 +610,9 @@ def _image_pixels_to_features(
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> list[torch.Tensor]:
+    ) -> torch.Tensor | list[torch.Tensor]:
+        if image_input["type"] == "image_embeds":
+            return image_input["image_embeds"]
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]

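Taken together, the last two hunks form a simple dispatch: pixel inputs are validated and routed through the vision tower, while embedding inputs are validated and returned untouched. A self-contained toy version of that branch; the types and shapes are illustrative, not the vLLM classes:

import torch
from typing import TypedDict


class ToyImageEmbeddingInputs(TypedDict):
    type: str
    image_embeds: torch.Tensor


def process_image_input(image_input: ToyImageEmbeddingInputs) -> torch.Tensor:
    # Mirrors the early return added above: embeddings skip the vision tower.
    if image_input["type"] == "image_embeds":
        return image_input["image_embeds"]
    raise NotImplementedError("pixel_values would require the vision tower")


embeds = torch.randn(1, 256, 2560)  # hypothetical (ni, nf, hs)
out = process_image_input({"type": "image_embeds", "image_embeds": embeds})
assert out is embeds  # returned as-is: no vision-tower forward pass
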
vllm/multimodal/parse.py

Lines changed: 6 additions & 5 deletions
@@ -359,8 +359,9 @@ def __init__(
         )
         self.video_needs_metadata = video_needs_metadata
 
-    def _is_embeddings(
-        self, data: object
+    @classmethod
+    def is_embeddings(
+        cls, data: object
     ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
         if isinstance(data, torch.Tensor):
             return data.ndim == 3
@@ -420,7 +421,7 @@ def _parse_audio_data(
         ):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return AudioEmbeddingItems(data)
 
         data_items: list[AudioItem]
@@ -458,7 +459,7 @@ def _parse_image_data(
         if self._is_empty(data):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return ImageEmbeddingItems(data)
 
         if (
@@ -484,7 +485,7 @@ def _parse_video_data(
         if self._is_empty(data):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return VideoEmbeddingItems(data)
 
         data_items: list[VideoItem]

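Promoting `_is_embeddings` to a public classmethod matters because code outside the parser (`processor.py` below) now needs the same check. The rule it encodes is roughly: a 3-D tensor, or a non-empty list of 2-D tensors, counts as precomputed embeddings rather than raw media. A standalone sketch of that rule, not the exact vLLM implementation:

import torch


def looks_like_embeddings(data: object) -> bool:
    # Stacked (num_items, num_features, hidden_size) embeddings.
    if isinstance(data, torch.Tensor):
        return data.ndim == 3
    # A list of per-item (num_features, hidden_size) embeddings.
    if isinstance(data, list) and data:
        return all(isinstance(t, torch.Tensor) and t.ndim == 2 for t in data)
    return False


assert looks_like_embeddings(torch.randn(2, 256, 2560))   # batched embeddings
assert looks_like_embeddings([torch.randn(256, 2560)])    # per-item embeddings
assert not looks_like_embeddings(torch.randn(256, 2560))  # lone 2-D tensor: raw data
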
vllm/v1/engine/processor.py

Lines changed: 7 additions & 1 deletion
@@ -14,6 +14,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
@@ -340,7 +341,12 @@ def _extract_mm_data(p: PromptType):
 
         mm_uuids: dict[str, list[str | None] | str] = {}
         for modality, data in mm_data.items():
-            n = len(data) if isinstance(data, list) else 1
+            # Hash each item for embedding inputs.
+            n = (
+                len(data)
+                if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
+                else 1
+            )
             mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
         return mm_uuids

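The effect of the new branch: an embedding tensor used to count as a single item and receive one UUID, whereas each image along its first dimension now gets its own, keeping per-item cache keys distinct. A worked example under assumed shapes, using the UUID format from the code above:

import torch

request_id, modality = "req-0", "image"
data = torch.randn(3, 256, 2560)  # hypothetical embeddings for three images

# len() of a tensor is the size of its first dimension, so embedding
# inputs now fan out to one UUID per image.
n = len(data)
mm_uuids = [f"{request_id}-{modality}-{i}" for i in range(n)]
assert mm_uuids == ["req-0-image-0", "req-0-image-1", "req-0-image-2"]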