Commit 329f612

Chunked Prefill VLM (#3188)
* add logic
* working
* add encoder cache free
* fixes
* fix idefics
* update pixel_values
* add improvements
* add improvements
* improve
* nit
* fix inputs_embeds
* nit
* optimizations
* add prometheus port
* rename vars
* rename vars
* nit
* disable chunking for qwen
* review comments
* remove port
* improve headdim
* remove kwargs and redundant args
* fix qwen2_5
* fix config image_token_id error
* fix test
* update paligemma
* fix paligemma text
* minor fix
* fix qwen test
* fix qwen test
1 parent 533eee5 commit 329f612

File tree: 15 files changed, +1108 −509 lines

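At a high level, the commit splits multimodal prefill into three steps so long prompts can be chunked: get_vision_embeds encodes pixel_values into a flat feature tensor (which the serving layer can cache and later free), get_inputs_embeds embeds a chunk of input_ids and splices the vision features into the image-token positions, and forward now consumes inputs_embeds instead of input_ids. A minimal sketch of that call sequence, assuming a model exposing the methods added below; the encoder_cache plumbing and helper name are hypothetical, not TGI internals:

import torch
from typing import Dict, Optional

def prefill_chunk(
    model,
    input_ids: torch.Tensor,                   # one chunk of the prompt
    pixel_values: Optional[torch.Tensor],      # None once the image is already encoded
    encoder_cache: Dict[str, torch.Tensor],    # freed when the request finishes
    image_id: str,
    **forward_kwargs,                          # position_ids, kv_cache, etc.
):
    # 1. Encode pixels once; later chunks reuse the cached vision features.
    if pixel_values is not None and image_id not in encoder_cache:
        encoder_cache[image_id] = model.get_vision_embeds(pixel_values=pixel_values)

    # 2. Embed this token chunk and overwrite image-token positions with the
    #    matching vision features (the real scheduler passes only the slice of
    #    features corresponding to the image tokens present in this chunk).
    inputs_embeds = model.get_inputs_embeds(
        input_ids=input_ids,
        vision_embeds=encoder_cache.get(image_id),
    )

    # 3. The text model only ever sees embeddings now.
    return model.forward(inputs_embeds=inputs_embeds, **forward_kwargs)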

server/text_generation_server/models/__init__.py

Lines changed: 7 additions & 5 deletions
@@ -128,9 +128,6 @@
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
     FlashGPTNeoXForCausalLM,
 )
-from text_generation_server.models.pali_gemma import (
-    PaliGemmaBatch,
-)
 from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
     PaliGemmaForConditionalGeneration,
 )
@@ -1196,6 +1193,7 @@ def get_model(
                 default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                support_chunking=False,
             )
         elif FLASH_TRANSFORMERS_BACKEND:
             from transformers import Gemma3ForConditionalGeneration as Gemma3Model
@@ -1208,6 +1206,7 @@ def get_model(
                 speculator=speculator,
                 dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
+                support_chunking=False,
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma3"))
@@ -1523,6 +1522,8 @@ def get_model(
                 kv_cache_dtype=kv_cache_dtype,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                # TODO: Fix bug in rust image_text_replacement implementation
+                support_chunking=False,
             )
         # TODO: Uncomment when transformers is refactored
         # elif FLASH_TRANSFORMERS_BACKEND:
@@ -1554,6 +1555,8 @@ def get_model(
                 lora_adapter_ids=lora_adapter_ids,
                 config_class=Qwen2_5_VLConfig,
                 processor_class=Qwen2_5_VLProcessor,
+                # TODO: Fix bug in rust image_text_replacement implementation
+                support_chunking=False,
             )
         # TODO: Uncomment when transformers is refactored
         # elif FLASH_TRANSFORMERS_BACKEND:
@@ -1583,6 +1586,7 @@ def get_model(
                 default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
+                support_chunking=False,
             )
         # TODO: Uncomment when transformers is refactored and cross attn is added
         # elif FLASH_TRANSFORMERS_BACKEND:
@@ -1676,7 +1680,6 @@ def get_model(
                 default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
-                batch_class=PaliGemmaBatch,
             )
         elif FLASH_TRANSFORMERS_BACKEND:
             from transformers import PaliGemmaForConditionalGeneration as PaliGemmaModel
@@ -1689,7 +1692,6 @@ def get_model(
                 speculator=speculator,
                 dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
-                batch_class=PaliGemmaBatch,
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("PaliGemma"))
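The support_chunking=False entries opt specific VLMs out of chunked prefill (Gemma3 and the Qwen2-VL family, with a TODO pointing at a bug in the Rust image_text_replacement logic), while PaliGemma drops its dedicated PaliGemmaBatch now that the shared VLM batch handles embeddings. A rough, purely illustrative sketch of what such a flag gates on the scheduling side, not the actual router code:

def plan_prefill(prompt_len: int, chunk_size: int, support_chunking: bool):
    # Split a prompt into prefill chunks, or keep it whole if the model
    # opted out of chunking (illustrative helper, not TGI's scheduler).
    if not support_chunking or prompt_len <= chunk_size:
        return [(0, prompt_len)]
    return [
        (start, min(start + chunk_size, prompt_len))
        for start in range(0, prompt_len, chunk_size)
    ]

# e.g. plan_prefill(10_000, 4_096, support_chunking=False) -> [(0, 10000)]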

server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py

Lines changed: 35 additions & 24 deletions
@@ -700,6 +700,7 @@ def __init__(self, prefix, config, weights):
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
         )
+        self.dtype = weights.dtype

     def get_attention_mask(
         self,
@@ -762,9 +763,42 @@ def get_attention_mask(
         else:
             return torch.where(full_attention_mask, 0, min_dtype).to(device)

-    def forward(
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        pixel_values = pixel_values.to(dtype=self.dtype)
+        image_outputs = self.vision_model(pixel_values)
+        vision_outputs = self.post_vision_model_layernorm(
+            image_outputs.last_hidden_state
+        )
+        image_features = self.multimodal_projector(vision_outputs)
+        image_features = image_features.view(-1, image_features.shape[-1])
+        return image_features
+
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+        if vision_embeds is not None:
+            # Replace the image token embeddings with the vision features
+            image_token_mask = (input_ids == self.config.image_token_index).to(
+                input_ids.device
+            )
+            inputs_embeds[image_token_mask] = vision_embeds.view(
+                -1, vision_embeds.shape[-1]
+            )
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -777,35 +811,12 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         # Unused here
         attention_mask: Optional[torch.BoolTensor] = None,
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        image_sizes: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
         if cu_seqlen_prefill is not None:
             max_s += 1
             position_ids += 1

-        if pixel_values is not None:
-            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
-            image_outputs = self.vision_model(pixel_values)
-            vision_outputs = self.post_vision_model_layernorm(
-                image_outputs.last_hidden_state
-            )
-            image_features = self.multimodal_projector(vision_outputs)
-
-            image_token_mask = (input_ids == self.config.image_token_index).to(
-                input_ids.device
-            )
-            inputs_embeds[image_token_mask] = image_features.view(
-                -1, image_features.shape[-1]
-            )
-            attention_mask = self.get_attention_mask(
-                input_ids,
-                cu_seqlen_prefill,
-                inputs_embeds.dtype,
-            )
         # Use flash attention for text-only input
         # else:
         #     if cu_seqlen_prefill is not None:
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

Lines changed: 1 addition & 2 deletions
@@ -116,11 +116,10 @@ def __init__(self, prefix: str, config, weights, layer_id):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        if hasattr(config, "head_dim"):
+        if getattr(config, "head_dim", None) is not None:
             self.head_size = config.head_dim
         else:
             self.head_size = self.hidden_size // self.num_heads
-
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
             dim=self.head_size,
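The Mistral tweak matters for checkpoints whose config carries head_dim set explicitly to null: hasattr(config, "head_dim") is still True in that case, so the old branch assigned None as the head size, whereas the new check falls back to hidden_size // num_attention_heads. A standalone illustration with a stand-in config object:

from types import SimpleNamespace

config = SimpleNamespace(head_dim=None, hidden_size=4096, num_attention_heads=32)

print(hasattr(config, "head_dim"))                    # True  -> old code would pick None
print(getattr(config, "head_dim", None) is not None)  # False -> take the fallback branch

head_size = (
    config.head_dim
    if getattr(config, "head_dim", None) is not None
    else config.hidden_size // config.num_attention_heads
)
assert head_size == 128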

server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py

Lines changed: 32 additions & 20 deletions
@@ -62,10 +62,40 @@ def __init__(self, prefix, config, weights):
         self.pad_token_id = (
             config.pad_token_id if config.pad_token_id is not None else -1
         )
+        self.dtype = weights.dtype

-    def forward(
+    def get_vision_embeds(
+        self,
+        pixel_values: torch.FloatTensor,
+        pixel_attention_mask: Optional[torch.FloatTensor] = None,
+        image_sizes: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+    ):
+        pixel_values = pixel_values.to(dtype=self.dtype)
+        image_outputs = self.vision_tower(pixel_values)
+        last_hidden_state = self.post_vision_tower_layernorm(
+            image_outputs.last_hidden_state
+        )
+        image_features = self.multi_modal_projector(last_hidden_state)
+        image_features = image_features.view(-1, image_features.shape[-1])
+        return image_features
+
+    def get_inputs_embeds(
         self,
         input_ids: torch.Tensor,
+        vision_embeds: torch.Tensor = None,
+    ):
+        inputs_embeds = self.text_model.embed_tokens(input_ids)
+
+        if vision_embeds is not None:
+            mask = input_ids == self.config.image_token_index
+            inputs_embeds[mask] = vision_embeds
+
+        return inputs_embeds
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
         position_ids: torch.Tensor,
         cu_seqlen_prefill: Optional[torch.Tensor],
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
@@ -75,33 +105,15 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
-        pixel_values: torch.FloatTensor = None,
         # Unused here
-        pixel_attention_mask: Optional[torch.BoolTensor] = None,
-        image_sizes: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.BoolTensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
-        image_grid_thw: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.
         if cu_seqlen_prefill is not None:
             max_s += 1
             position_ids += 1

-        if pixel_values is not None:
-            pixel_values = pixel_values.to(dtype=inputs_embeds.dtype)
-            image_outputs = self.vision_tower(pixel_values)
-            last_hidden_state = self.post_vision_tower_layernorm(
-                image_outputs.last_hidden_state
-            )
-            image_features = self.multi_modal_projector(last_hidden_state)
-
-            # mask where image or padding tokens
-            mask = input_ids == self.config.image_token_index
-
-            # insert image features into input embeddings
-            inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
-
         hidden_states = self.text_model.model(
             inputs_embeds=inputs_embeds,
             position_ids=position_ids,
