Commit 64192d5

[Bugfix] Revert custom attention mask for gemma3-mm (#28995)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
1 parent fe25772 commit 64192d5

File tree

4 files changed: +1, -172 lines changed

  vllm/config/model.py
  vllm/model_executor/models/gemma3_mm.py
  vllm/transformers_utils/config.py
  vllm/v1/worker/gpu_model_runner.py


vllm/config/model.py

Lines changed: 0 additions & 5 deletions
@@ -32,7 +32,6 @@
     try_get_generation_config,
     try_get_safetensors_metadata,
     try_get_tokenizer_config,
-    uses_custom_attention_masks,
     uses_mrope,
 )
 from vllm.transformers_utils.gguf_utils import (
@@ -1625,10 +1624,6 @@ def uses_alibi(self) -> bool:
     def uses_mrope(self) -> bool:
         return uses_mrope(self.hf_config)
 
-    @property
-    def uses_custom_attention_masks(self) -> bool:
-        return uses_custom_attention_masks(self.hf_config)
-
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

vllm/model_executor/models/gemma3_mm.py

Lines changed: 1 addition & 137 deletions
@@ -596,7 +596,7 @@ def _process_image_input(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -644,142 +644,6 @@ def forward(
 
         return hidden_states
 
-    def generate_attention_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-    ) -> dict[str, Any]:
-        """Generate custom attention masks for Gemma3 multimodal inputs.
-
-        This is called by V1 engine's gpu_model_runner during preprocessing
-        to generate attention masks that allow bidirectional attention between
-        image tokens while maintaining causal attention for text.
-        """
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i]
-            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_idx, seq_len in enumerate(seq_lens):
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-
-            # Find image token positions
-            img_pos = input_token_ids == self.config.image_token_index
-
-            start_idx = end_idx
-
-            # Create a global causal mask
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0 (causal attention)
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Enable bidirectional attention between image tokens
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            # GGUF compatibility: config might be Gemma3TextConfig directly
-            text_config = getattr(self.config, "text_config", self.config)
-            sliding_window = text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024)
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-
-        return {
-            "has_images": True,
-            "seq_lens": seq_lens,
-            "global_attn_masks": global_attn_masks,
-            "local_attn_masks": local_attn_masks,
-        }
-
-    def prepare_attn_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-        **kwargs,
-    ):
-        kwargs["has_images"] = True
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i].item()
-            if i < num_seqs - 1:
-                end_idx = start_indices[i + 1].item()
-            else:
-                end_idx = len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        kwargs["seq_lens"] = seq_lens
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            start_idx = end_idx
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = input_token_ids == self.config.image_token_index
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            sliding_window = self.config.text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-        kwargs["global_attn_masks"] = global_attn_masks
-        kwargs["local_attn_masks"] = local_attn_masks
-        return kwargs
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
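For reference, the reverted mask construction can be reproduced outside vLLM. Below is a minimal standalone sketch (not vLLM code) of what `prepare_attn_masks` / `generate_attention_masks` computed for a single sequence: causal attention for text, bidirectional attention among image tokens, and an optional sliding-window local mask. The function name `build_gemma3_style_masks` and the example inputs are illustrative assumptions, not part of any vLLM API.

```python
import torch


def build_gemma3_style_masks(
    input_ids: torch.Tensor,          # (seq_len,) token ids of a single sequence
    image_token_id: int,
    sliding_window: int | None = None,
    mask_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Return (global_mask, local_mask), each shaped (1, 1, seq_len, seq_len)."""
    seq_len = input_ids.shape[0]

    # Global mask: -inf above the diagonal (future tokens), 0 on/below it (causal).
    global_mask = torch.full(
        (1, 1, seq_len, seq_len),
        float("-inf"),
        dtype=mask_dtype,
        device=input_ids.device,
    ).triu(diagonal=1)

    # Unmask (query, key) pairs where both positions are image tokens (bidirectional).
    img_pos = input_ids == image_token_id
    img_mask = torch.zeros_like(global_mask)
    img_mask[:, :, :, img_pos] += 1
    img_mask[:, :, img_pos, :] += 1
    global_mask = torch.where(img_mask == 2, 0.0, global_mask)

    # Local variant: additionally mask keys further back than the sliding window.
    local_mask = None
    if sliding_window is not None:
        too_far = torch.ones_like(global_mask).tril(diagonal=-sliding_window)
        local_mask = torch.where(too_far == 0, global_mask, float("-inf"))
    return global_mask, local_mask


# Example: a 7-token "sequence" whose positions 2-4 are image placeholder tokens.
ids = torch.tensor([2, 10, 99, 99, 99, 11, 1])
global_mask, local_mask = build_gemma3_style_masks(ids, image_token_id=99, sliding_window=4)
print(global_mask[0, 0])  # row i shows which keys query i may attend to
```

The `img_mask == 2` trick marks exactly the (query, key) pairs where both positions are image tokens, which is where the causal constraint is lifted; everything else keeps the standard lower-triangular pattern.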

vllm/transformers_utils/config.py

Lines changed: 0 additions & 11 deletions
@@ -520,17 +520,6 @@ def is_interleaved(config: PretrainedConfig) -> bool:
     return False
 
 
-def uses_custom_attention_masks(config: PretrainedConfig) -> bool:
-    """Detect if model uses custom attention mask generation for multimodal.
-
-    Some multimodal models require custom attention masks that enable
-    bidirectional attention between image tokens while maintaining causal
-    attention for text tokens. Currently applies to Gemma3 multimodal models.
-    """
-    architectures = getattr(config, "architectures", [])
-    return "Gemma3ForConditionalGeneration" in architectures
-
-
 def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str):
     """
     Update kwargs for AutoConfig initialization based on model_type
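The removed helper keyed the behavior off the `architectures` field of the Hugging Face config rather than `model_type`. A minimal sketch of that check, with `SimpleNamespace` standing in for a transformers `PretrainedConfig` (the real helper took a `PretrainedConfig`; everything else here is illustrative):

```python
# Sketch of the architecture check performed by the removed helper.
# Only the `architectures` attribute of the config is consulted.
from types import SimpleNamespace


def uses_custom_attention_masks(config) -> bool:
    architectures = getattr(config, "architectures", []) or []
    return "Gemma3ForConditionalGeneration" in architectures


gemma3_mm = SimpleNamespace(architectures=["Gemma3ForConditionalGeneration"])
llama = SimpleNamespace(architectures=["LlamaForCausalLM"])
print(uses_custom_attention_masks(gemma3_mm))  # True
print(uses_custom_attention_masks(llama))      # False
```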

vllm/v1/worker/gpu_model_runner.py

Lines changed: 0 additions & 19 deletions
@@ -324,7 +324,6 @@ def __init__(
         # Multi-modal data support
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope
-        self.uses_custom_attention_masks = model_config.uses_custom_attention_masks
         self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
             model_config
         )
@@ -2352,24 +2351,6 @@ def _preprocess(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-
-            # Generate custom attention masks for models that require them.
-            # V1 pre-generates embeddings, so forward() skips prepare_attn_masks().
-            # Check mm_features (mm_embeds is empty during decode).
-            has_mm_features = any(
-                req_state.mm_features for req_state in self.requests.values()
-            )
-            if (
-                self.uses_custom_attention_masks
-                and has_mm_features
-                and hasattr(self.model, "generate_attention_masks")
-            ):
-                mask_kwargs = self.model.generate_attention_masks(
-                    self.input_ids.gpu[:num_scheduled_tokens],
-                    self.positions.gpu[:num_scheduled_tokens],
-                    mask_dtype=self.model.dtype,
-                )
-                model_kwargs.update(mask_kwargs)
         elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
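For context, the reverted runner hunk implemented an opt-in hook: when the batch carried multimodal features and the loaded model exposed `generate_attention_masks`, its return value was merged into the forward kwargs. The sketch below reproduces that dispatch shape in isolation; `ToyModel`, `build_model_kwargs`, and the tensors are illustrative stand-ins, and the config-level `uses_custom_attention_masks` gate is folded into the `hasattr` check for brevity.

```python
import torch


class ToyModel(torch.nn.Module):
    """Stand-in for a multimodal model that opts into custom attention masks."""

    dtype = torch.float32

    def generate_attention_masks(self, input_ids, positions, mask_dtype):
        # A real model would build per-sequence masks here (see the gemma3_mm sketch above).
        return {"has_images": True, "seq_lens": [int(input_ids.numel())]}


def build_model_kwargs(model, input_ids, positions, has_mm_features, base_kwargs):
    """Mirror of the reverted dispatch: call the hook only when present and needed."""
    model_kwargs = dict(base_kwargs)
    if has_mm_features and hasattr(model, "generate_attention_masks"):
        model_kwargs.update(
            model.generate_attention_masks(input_ids, positions, mask_dtype=model.dtype)
        )
    return model_kwargs


ids = torch.tensor([1, 99, 99, 2])
pos = torch.arange(ids.numel())
print(build_model_kwargs(ToyModel(), ids, pos, has_mm_features=True, base_kwargs={}))
# -> {'has_images': True, 'seq_lens': [4]}
```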
