22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33import math
44from collections .abc import Iterable , Mapping , Sequence
5- from typing import Annotated , Any , Literal
5+ from typing import Annotated , Any , Literal , TypeAlias
66
77import torch
88from torch import nn
2020 MultiModalFieldConfig ,
2121 MultiModalKwargsItems ,
2222)
23- from vllm .multimodal .parse import ImageProcessorItems , ImageSize , MultiModalDataItems
23+ from vllm .multimodal .parse import (
24+ ImageEmbeddingItems ,
25+ ImageProcessorItems ,
26+ ImageSize ,
27+ MultiModalDataItems ,
28+ )
2429from vllm .multimodal .processing import (
2530 BaseMultiModalProcessor ,
2631 BaseProcessingInfo ,
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
7176 num_patches : Annotated [torch .Tensor , TensorShape ("bn" )]
7277
7378
74- Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ImageEmbeddingInputs(TensorSchema):
    """Schema for image inputs supplied as embeddings rather than pixels.

    This is the ``"image_embeds"`` variant of ``Gemma3ImageInputs``; the
    ``type`` field is the tag used to tell it apart from the pixel-value
    variant.

    Dimensions:
        - ni, nf, hs: symbolic names checked by ``TensorShape``.
          NOTE(review): presumably ni = number of images, nf = number of
          image features per image, hs = hidden size of the language
          model -- confirm against the other TensorSchema classes.
    """

    # Discriminator tag; always "image_embeds" for this variant.
    type: Literal["image_embeds"] = "image_embeds"

    # Rank-3 tensor of image embeddings, validated as ("ni", "nf", "hs").
    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("ni", "nf", "hs"),
    ]


# Union of every accepted image-input variant (pixel values or embeddings).
Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
7588
7689
7790class Gemma3ProcessingInfo (BaseProcessingInfo ):
@@ -178,20 +191,23 @@ def get_num_crops(
178191 def get_image_repl (
179192 self ,
180193 * ,
181- image_width : int ,
182- image_height : int ,
194+ image_width : int | None ,
195+ image_height : int | None ,
196+ num_crops : int | None = None ,
183197 processor : Gemma3Processor | None ,
184198 ) -> PromptUpdateDetails [str ]:
185199 if processor is None :
186200 processor = self .get_hf_processor ()
187201
188202 boi_token = processor .boi_token
189203
190- num_crops = self .get_num_crops (
191- image_width = image_width ,
192- image_height = image_height ,
193- processor = processor ,
194- )
204+ if num_crops is None :
205+ assert image_width is not None and image_height is not None
206+ num_crops = self .get_num_crops (
207+ image_width = image_width ,
208+ image_height = image_height ,
209+ processor = processor ,
210+ )
195211
196212 if num_crops == 0 :
197213 image_text = boi_token
@@ -321,6 +337,7 @@ def _get_mm_fields_config(
321337 return dict (
322338 pixel_values = MultiModalFieldConfig .flat_from_sizes ("image" , num_patches ),
323339 num_patches = MultiModalFieldConfig .batched ("image" ),
340+ image_embeds = MultiModalFieldConfig .batched ("image" ),
324341 )
325342
326343 def _get_prompt_updates (
@@ -333,7 +350,19 @@ def _get_prompt_updates(
333350 image_token = hf_processor .boi_token
334351
335352 def get_replacement_gemma3 (item_idx : int ):
336- images = mm_items .get_items ("image" , ImageProcessorItems )
353+ images = mm_items .get_items (
354+ "image" , (ImageEmbeddingItems , ImageProcessorItems )
355+ )
356+
357+ if isinstance (images , ImageEmbeddingItems ):
358+ # For image embedding inputs, only support no crops cases
359+ # since it's not supported in hf processor anyway
360+ return self .info .get_image_repl (
361+ image_width = None ,
362+ image_height = None ,
363+ num_crops = 0 ,
364+ processor = hf_processor ,
365+ )
337366
338367 image_size = images .get_image_size (item_idx )
339368 return self .info .get_image_repl (
@@ -557,17 +586,19 @@ def _parse_and_validate_image_input(
557586 pixel_values = kwargs .pop ("pixel_values" , None )
558587 num_patches = kwargs .pop ("num_patches" , None )
559588 image_embeds = kwargs .pop ("image_embeds" , None )
560- assert image_embeds is None , "Gemma3 does not support image_embeds."
561- if pixel_values is None :
562- return None
563589
564- image_size = self .config .vision_config .image_size
565-
566- return Gemma3ImagePixelInputs (
567- pixel_values = pixel_values ,
568- num_patches = num_patches ,
569- resolve_bindings = {"h" : image_size , "w" : image_size },
570- )
590+ if pixel_values is not None :
591+ image_size = self .config .vision_config .image_size
592+ return Gemma3ImagePixelInputs (
593+ pixel_values = pixel_values ,
594+ num_patches = num_patches ,
595+ resolve_bindings = {"h" : image_size , "w" : image_size },
596+ )
597+ elif image_embeds is not None :
598+ return Gemma3ImageEmbeddingInputs (
599+ image_embeds = image_embeds ,
600+ type = "image_embeds" ,
601+ )
571602
572603 def _image_pixels_to_features (
573604 self ,
@@ -579,7 +610,9 @@ def _image_pixels_to_features(
579610 def _process_image_input (
580611 self ,
581612 image_input : Gemma3ImageInputs ,
582- ) -> list [torch .Tensor ]:
613+ ) -> torch .Tensor | list [torch .Tensor ]:
614+ if image_input ["type" ] == "image_embeds" :
615+ return image_input ["image_embeds" ]
583616 assert self .vision_tower is not None
584617
585618 pixel_values = image_input ["pixel_values" ]
0 commit comments