@@ -24,8 +24,8 @@
 import math
 import re
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
-                    TypedDict)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -65,10 +65,12 @@
     "llm.lm_head": "lm_head",
 }
 
+RawImageType = Union[Image.Image, torch.Tensor]
 
-class MiniCPMVImageInput(TypedDict):
+
+class MiniCPMVRawImageInput(TypedDict):
     """Input mapper input with auxiliary data for computing image bounds."""
-    image: Image.Image
+    image: RawImageType
 
     # Image bounds token ids in 0-dim scalar tensor.
     im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
 
 
 class MiniCPMVImagePixelInputs(TypedDict):
-    pixel_values: List[torch.Tensor]
+    type: Literal["pixel_values"]
+    data: List[torch.Tensor]
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of the language model backbone.
+    """
+
+    image_bounds: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+
+    This should be in `(start, stop)` format.
+    """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
+                            MiniCPMVImageEmbeddingInputs]
+
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
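For illustration, a minimal sketch of how this tagged union is meant to be consumed: the `Literal` "type" key discriminates the two variants, so callers can branch without inspecting tensor containers. The `describe` helper below is hypothetical, not part of this change.

    def describe(inputs: MiniCPMVImageInputs) -> str:
        # Hypothetical helper: branch on the Literal "type" discriminator.
        if inputs["type"] == "image_embeds":
            # data is a single batched tensor of precomputed features.
            return f"precomputed embeddings of shape {tuple(inputs['data'].shape)}"
        # data is a list of per-image pixel tensors of varying sizes.
        return f"{len(inputs['data'])} per-image pixel tensors"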
@@ -194,22 +218,22 @@ def forward(self, x: torch.Tensor,
 
 
 def _build_image_input(ctx: InputContext,
-                       image: Image.Image) -> MiniCPMVImageInput:
+                       image: RawImageType) -> MiniCPMVRawImageInput:
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         trust_remote_code=ctx.model_config.trust_remote_code)
     if hasattr(tokenizer, "slice_start_id"):
-        return MiniCPMVImageInput(
+        return MiniCPMVRawImageInput(
             image=image,
             im_start_id=torch.tensor(tokenizer.im_start_id),
             im_end_id=torch.tensor(tokenizer.im_end_id),
             slice_start_id=torch.tensor(tokenizer.slice_start_id),
             slice_end_id=torch.tensor(tokenizer.slice_end_id))
     else:
-        return MiniCPMVImageInput(image=image,
-                                  im_start_id=torch.tensor(
-                                      tokenizer.im_start_id),
-                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+        return MiniCPMVRawImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id))
 
 
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
@@ -280,20 +304,25 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
 
     pattern = "(<image>./</image>)"
     images = multi_modal_data["image"]
-    if isinstance(images, Image.Image):
-        images = [images]
     image_tags = re.findall(pattern, prompt)
-
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
+        if isinstance(images, dict):
+            image_size_list = images.get("image_size_list")
+            images = [images.get("image_embeds")]
+        else:
+            if isinstance(images, Image.Image):
+                images = [images]
+            image_size_list = [image.size for image in images]
+
         text_chunks = prompt.split(pattern)
         new_prompt_chunks: List[str] = []
-        for i in range(len(images)):
+        for i in range(len(image_size_list)):
             new_prompt_chunks += [
                 text_chunks[i],
-                get_placeholder(images[i].size, i)
+                get_placeholder(image_size_list[i], i)
             ]
         new_prompt_chunks.append(text_chunks[-1])
         new_prompt = "".join(new_prompt_chunks)
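The branch above implies two accepted forms for `multi_modal_data["image"]`. A hedged sketch of both; the variable names and the embedding tensor's exact shape are illustrative assumptions, not taken from this diff:

    import torch
    from PIL import Image

    # Form 1: a raw image, preprocessed by the HF image processor downstream.
    mm_pixels = {"image": Image.new("RGB", (448, 448))}

    # Form 2: precomputed vision features, bypassing the image processor;
    # "image_size_list" supplies the original sizes used for placeholders.
    mm_embeds = {
        "image": {
            "image_embeds": torch.randn(1, 64, 2304),  # assumed shape
            "image_size_list": [(448, 448)],
        }
    }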
@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
     if not isinstance(data, list):
         raise ValueError(
             "Image input must be list of MiniCPMVImageInput, got (%s)", data)
-    batch_data = image_processor \
-        .preprocess([img["image"] for img in data], return_tensors="pt") \
-        .data
+
+    if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
+        batch_data = {
+            "image_embeds": data[0]['image'],
+        }
+    else:
+        batch_data = image_processor \
+            .preprocess([img["image"] for img in data], return_tensors="pt") \
+            .data
 
     if len(data) > 0:
         batch_data["im_start_id"] = data[0]["im_start_id"]
@@ -380,7 +415,7 @@ def __init__(
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImagePixelInputs],
+        image_inputs: Optional[MiniCPMVImageInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -389,7 +424,12 @@ def get_embedding(
         if image_inputs is None:  # No image
             vision_hidden_states = torch.tensor([], device=input_ids.device)
         else:
-            vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (image_inputs["data"].type(
+                    vlm_embedding.dtype).to(vlm_embedding.device))
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(
+                    image_inputs)
 
         # See NOTE in _parse_and_validate_inputs
         image_bounds = image_inputs["image_bounds"]
@@ -440,9 +480,23 @@ def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImagePixelInputs]:
+    ) -> Optional[MiniCPMVImageInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if image_embeds is not None:
+            return MiniCPMVImageEmbeddingInputs(
+                image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                    im_end_id, slice_start_id,
+                                                    slice_end_id),
+                data=image_embeds,
+                type="image_embeds",
+            )
 
         if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
@@ -477,19 +531,16 @@ def _parse_and_validate_inputs(
         if len(pixel_values_flat) == 0:
             return None
 
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
         if im_start_id is None:
             return None
 
         return MiniCPMVImagePixelInputs(
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            pixel_values=pixel_values_flat,
+            data=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
+            type="pixel_values",
         )
 
     def forward(
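`_get_image_bounds` is called on both paths but is not shown in this diff. A hypothetical reconstruction of how `(start, stop)` bounds could be derived from the boundary token ids, consistent with the `image_bounds` docstring above; this is a sketch, not the actual implementation:

    import torch

    def get_image_bounds_sketch(input_ids: torch.Tensor,
                                im_start_id: torch.Tensor,
                                im_end_id: torch.Tensor) -> torch.Tensor:
        # Content starts just after each <im_start> and stops at each <im_end>.
        starts = torch.where(input_ids == im_start_id)[0] + 1
        stops = torch.where(input_ids == im_end_id)[0]
        if starts.numel() == 0:
            return torch.zeros((0, 2), dtype=torch.long,
                               device=input_ids.device)
        return torch.stack([starts, stops], dim=1)  # shape (num_images, 2)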
@@ -610,8 +661,8 @@ def get_vision_embedding(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ def get_vision_embedding(
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
 
         return self.get_vision_embedding(pixel_values)
 
@@ -793,9 +844,9 @@ def get_vision_embedding(
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ def get_vision_embedding(
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
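Taken together, a hedged end-to-end sketch of serving precomputed embeddings through this path. The model id, prompt, and tensor shape are illustrative assumptions; the dict layout matches the input processor branch above:

    import torch
    from vllm import LLM

    # Assumed model id and feature shape, for illustration only.
    llm = LLM(model="openbmb/MiniCPM-V-2_6", trust_remote_code=True)
    image_embeds = torch.randn(1, 64, 3584)  # (num_images, feat_len, hidden)

    outputs = llm.generate({
        "prompt": "(<image>./</image>)\nWhat is in this image?",
        "multi_modal_data": {
            "image": {
                "image_embeds": image_embeds,
                "image_size_list": [(448, 448)],
            },
        },
    })
    print(outputs[0].outputs[0].text)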