Commit 04de905

[Model] support input image embedding for minicpmv (vllm-project#9237)

1 parent 07c11cf commit 04de905
3 files changed: +101 additions, -43 deletions
docs/source/models/supported_models.rst

Lines changed: 1 addition & 1 deletion
@@ -378,7 +378,7 @@ Text Generation
     - ✅︎
   * - :code:`MiniCPMV`
     - MiniCPM-V
-    - Image\ :sup:`+`
+    - Image\ :sup:`E+`
    - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     - ✅︎
     - ✅︎

Here the :sup:`E` superscript marks support for pre-computed image embeddings as input, alongside the existing :sup:`+` marker for multiple images per prompt.

docs/source/models/vlm.rst

Lines changed: 11 additions & 4 deletions
@@ -57,12 +57,19 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`
         print(generated_text)
 
     # Inference with image embeddings as input with additional parameters
-    # Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding.
-    image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
-    image_grid_thw = torch.load(...) # torch.Tensor of shape (1, 3)
+    # Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
+    mm_data = {}
+
+    image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
+    # For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
+    mm_data['image'] = {
+        "image_embeds": image_embeds,
+        "image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
+    }
+    # For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
     mm_data['image'] = {
         "image_embeds": image_embeds,
-        "image_grid_thw": image_grid_thw,
+        "image_size_list": [image.size] # list of image sizes
     }
     outputs = llm.generate({
         "prompt": prompt,

vllm/model_executor/models/minicpmv.py

Lines changed: 89 additions & 38 deletions
@@ -24,8 +24,8 @@
 import math
 import re
 from functools import partial
-from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple,
-                    TypedDict)
+from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -65,10 +65,12 @@
     "llm.lm_head": "lm_head",
 }
 
+RawImageType = Union[Image.Image, torch.Tensor]
 
-class MiniCPMVImageInput(TypedDict):
+
+class MiniCPMVRawImageInput(TypedDict):
     """Input mapper input with auxiliary data for computing image bounds."""
-    image: Image.Image
+    image: RawImageType
 
     # Image bounds token ids in 0-dim scaler tensor.
     im_start_id: torch.Tensor
@@ -78,7 +80,8 @@ class MiniCPMVImageInput(TypedDict):
 
 
 class MiniCPMVImagePixelInputs(TypedDict):
-    pixel_values: List[torch.Tensor]
+    type: Literal["pixel_values"]
+    data: List[torch.Tensor]
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`
 
@@ -101,6 +104,27 @@ class MiniCPMVImagePixelInputs(TypedDict):
     """
 
 
+class MiniCPMVImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
+
+    `hidden_size` must match the hidden size of language model backbone.
+    instead of a batched tensor.
+    """
+
+    image_bounds: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 2)`
+
+    This should be in `(start, stop)` format.
+    """
+
+
+MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
+                            MiniCPMVImageEmbeddingInputs]
+
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 
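The new type field acts as a Literal discriminator, letting downstream code (and a type checker) tell the two variants of MiniCPMVImageInputs apart. A self-contained sketch of the pattern, using simplified stand-in classes rather than the real vLLM definitions:

    # Minimal sketch of a Literal-discriminated TypedDict union; the class
    # names here are stand-ins, not the vLLM definitions above.
    from typing import List, Literal, TypedDict, Union

    import torch


    class PixelInputs(TypedDict):
        type: Literal["pixel_values"]
        data: List[torch.Tensor]


    class EmbeddingInputs(TypedDict):
        type: Literal["image_embeds"]
        data: torch.Tensor


    ImageInputs = Union[PixelInputs, EmbeddingInputs]


    def describe(inputs: ImageInputs) -> str:
        # Branching on the discriminator narrows the union for both readers
        # and type checkers.
        if inputs["type"] == "image_embeds":
            return f"pre-computed embeddings with shape {tuple(inputs['data'].shape)}"
        return f"{len(inputs['data'])} pixel tensor(s) for the vision tower"


    print(describe({"type": "image_embeds", "data": torch.zeros(1, 64, 8)}))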
@@ -194,22 +218,22 @@ def forward(self, x: torch.Tensor,
 
 
 def _build_image_input(ctx: InputContext,
-                       image: Image.Image) -> MiniCPMVImageInput:
+                       image: RawImageType) -> MiniCPMVRawImageInput:
     tokenizer = cached_get_tokenizer(
         ctx.model_config.tokenizer,
         trust_remote_code=ctx.model_config.trust_remote_code)
     if hasattr(tokenizer, "slice_start_id"):
-        return MiniCPMVImageInput(
+        return MiniCPMVRawImageInput(
             image=image,
             im_start_id=torch.tensor(tokenizer.im_start_id),
             im_end_id=torch.tensor(tokenizer.im_end_id),
             slice_start_id=torch.tensor(tokenizer.slice_start_id),
             slice_end_id=torch.tensor(tokenizer.slice_end_id))
     else:
-        return MiniCPMVImageInput(image=image,
-                                  im_start_id=torch.tensor(
-                                      tokenizer.im_start_id),
-                                  im_end_id=torch.tensor(tokenizer.im_end_id))
+        return MiniCPMVRawImageInput(
+            image=image,
+            im_start_id=torch.tensor(tokenizer.im_start_id),
+            im_end_id=torch.tensor(tokenizer.im_end_id))
 
 
 def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
@@ -280,20 +304,25 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
 
     pattern = "(<image>./</image>)"
     images = multi_modal_data["image"]
-    if isinstance(images, Image.Image):
-        images = [images]
     image_tags = re.findall(pattern, prompt)
-
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
+        if isinstance(images, dict):
+            image_size_list = images.get("image_size_list")
+            images = [images.get("image_embeds")]
+        else:
+            if isinstance(images, Image.Image):
+                images = [images]
+            image_size_list = [image.size for image in images]
+
         text_chunks = prompt.split(pattern)
         new_prompt_chunks: List[str] = []
-        for i in range(len(images)):
+        for i in range(len(image_size_list)):
             new_prompt_chunks += [
                 text_chunks[i],
-                get_placeholder(images[i].size, i)
+                get_placeholder(image_size_list[i], i)
             ]
         new_prompt_chunks.append(text_chunks[-1])
         new_prompt = "".join(new_prompt_chunks)
@@ -323,9 +352,15 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
     if not isinstance(data, list):
         raise ValueError(
             "Image input must be list of MiniCPMVImageInput, got (%s)", data)
-    batch_data = image_processor \
-        .preprocess([img["image"] for img in data], return_tensors="pt") \
-        .data
+
+    if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor):
+        batch_data = {
+            "image_embeds": data[0]['image'],
+        }
+    else:
+        batch_data = image_processor \
+            .preprocess([img["image"] for img in data], return_tensors="pt") \
+            .data
 
     if len(data) > 0:
         batch_data["im_start_id"] = data[0]["im_start_id"]
@@ -380,7 +415,7 @@ def __init__(
     def get_embedding(
         self,
         input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImagePixelInputs],
+        image_inputs: Optional[MiniCPMVImageInputs],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
         if hasattr(self.config, "scale_emb"):
@@ -389,7 +424,12 @@ def get_embedding(
         if image_inputs is None:  # No image
             vision_hidden_states = torch.tensor([], device=input_ids.device)
         else:
-            vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (image_inputs["data"].type(
+                    vlm_embedding.dtype).to(vlm_embedding.device))
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(
+                    image_inputs)
 
         # See NOTE in _parse_and_validate_inputs
         image_bounds = image_inputs["image_bounds"]
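
In other words, when pre-computed features are supplied the model only aligns them to the language model's dtype and device and later writes them into the placeholder span given by image_bounds; the vision tower is never invoked. A standalone toy sketch of that idea (shapes, values and the scatter loop are illustrative, not the commit's code):

    # Standalone illustration only; tensor sizes and bounds are invented.
    import torch

    hidden_size = 8
    vlm_embedding = torch.zeros(12, hidden_size, dtype=torch.bfloat16)  # 12 token embeddings
    image_embeds = torch.rand(1, 4, hidden_size)                        # (num_images, feature_size, hidden)
    image_bounds = torch.tensor([[3, 7]])                               # tokens 3..6 are image placeholders

    # Embedding path: only dtype/device alignment, no vision encoder or resampler.
    vision_hidden_states = image_embeds.type(vlm_embedding.dtype).to(vlm_embedding.device)

    for (start, stop), feats in zip(image_bounds.tolist(), vision_hidden_states):
        vlm_embedding[start:stop] = feats  # place features into the (start, stop) span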
@@ -440,9 +480,23 @@ def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
         **kwargs: object,
-    ) -> Optional[MiniCPMVImagePixelInputs]:
+    ) -> Optional[MiniCPMVImageInputs]:
         pixel_values = kwargs.pop("pixel_values", [])
         tgt_sizes = kwargs.pop("tgt_sizes", [])
+        im_start_id = kwargs.pop("im_start_id", None)
+        im_end_id = kwargs.pop("im_end_id", None)
+        slice_start_id = kwargs.pop("slice_start_id", None)
+        slice_end_id = kwargs.pop("slice_end_id", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if image_embeds is not None:
+            return MiniCPMVImageEmbeddingInputs(
+                image_bounds=self._get_image_bounds(input_ids, im_start_id,
+                                                    im_end_id, slice_start_id,
+                                                    slice_end_id),
+                data=image_embeds,
+                type="image_embeds",
+            )
 
         if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
@@ -477,19 +531,16 @@ def _parse_and_validate_inputs(
         if len(pixel_values_flat) == 0:
             return None
 
-        im_start_id = kwargs.pop("im_start_id", None)
-        im_end_id = kwargs.pop("im_end_id", None)
-        slice_start_id = kwargs.pop("slice_start_id", None)
-        slice_end_id = kwargs.pop("slice_end_id", None)
         if im_start_id is None:
             return None
 
         return MiniCPMVImagePixelInputs(
             image_bounds=self._get_image_bounds(input_ids, im_start_id,
                                                 im_end_id, slice_start_id,
                                                 slice_end_id),
-            pixel_values=pixel_values_flat,
+            data=pixel_values_flat,
             tgt_sizes=torch.stack(tgt_sizes_flat),
+            type="pixel_values",
         )
 
     def forward(
@@ -610,8 +661,8 @@ def get_vision_embedding(
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
         raise NotImplementedError
 
     def is_default_weight_loading(self, name: str) -> bool:
@@ -705,9 +756,9 @@ def get_vision_embedding(
             res.append(self.resampler(vision_embedding, tgt_size))
         return torch.vstack(res)
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
 
         return self.get_vision_embedding(pixel_values)
 
@@ -793,9 +844,9 @@ def get_vision_embedding(
         vision_embedding = self.resampler(vision_embedding, tgt_sizes)
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device
@@ -909,9 +960,9 @@ def get_vision_embedding(
         )
         return vision_embedding
 
-    def get_vision_hidden_states(
-            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
-        pixel_values = data["pixel_values"]
+    def get_vision_hidden_states(self,
+                                 data: MiniCPMVImageInputs) -> torch.Tensor:
+        pixel_values = data["data"]
         tgt_sizes = data["tgt_sizes"]
 
         device = self.vpm.embeddings.position_embedding.weight.device

0 commit comments