@@ -596,7 +596,7 @@ def _process_image_input(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
@@ -644,142 +644,6 @@ def forward(

         return hidden_states

-    def generate_attention_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-    ) -> dict[str, Any]:
-        """Generate custom attention masks for Gemma3 multimodal inputs.
-
-        This is called by V1 engine's gpu_model_runner during preprocessing
-        to generate attention masks that allow bidirectional attention between
-        image tokens while maintaining causal attention for text.
-        """
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i]
-            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_idx, seq_len in enumerate(seq_lens):
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-
-            # Find image token positions
-            img_pos = input_token_ids == self.config.image_token_index
-
-            start_idx = end_idx
-
-            # Create a global causal mask
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0 (causal attention)
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Enable bidirectional attention between image tokens
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            # GGUF compatibility: config might be Gemma3TextConfig directly
-            text_config = getattr(self.config, "text_config", self.config)
-            sliding_window = text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024)
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-
-        return {
-            "has_images": True,
-            "seq_lens": seq_lens,
-            "global_attn_masks": global_attn_masks,
-            "local_attn_masks": local_attn_masks,
-        }
-
-    def prepare_attn_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-        **kwargs,
-    ):
-        kwargs["has_images"] = True
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i].item()
-            if i < num_seqs - 1:
-                end_idx = start_indices[i + 1].item()
-            else:
-                end_idx = len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        kwargs["seq_lens"] = seq_lens
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            start_idx = end_idx
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = input_token_ids == self.config.image_token_index
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            sliding_window = self.config.text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-        kwargs["global_attn_masks"] = global_attn_masks
-        kwargs["local_attn_masks"] = local_attn_masks
-        return kwargs
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
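
For readers skimming the removed code: both deleted methods start from a standard causal mask and then open up bidirectional attention between positions holding the image placeholder token, optionally adding a sliding-window local mask. The standalone sketch below is my own illustration of that construction, not vLLM code; `IMAGE_TOKEN_ID` and `build_global_mask` are hypothetical stand-ins for `config.image_token_index` and the per-sequence loop body above.

# Standalone illustration of the global-mask construction removed in this commit.
# IMAGE_TOKEN_ID is a hypothetical placeholder for config.image_token_index.
import torch

IMAGE_TOKEN_ID = 99

def build_global_mask(input_token_ids: torch.Tensor,
                      mask_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    seq_len = input_token_ids.numel()
    # Causal base: -inf strictly above the diagonal, 0 on and below it.
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), dtype=mask_dtype)
    mask = mask.triu(diagonal=1)
    # Let image-token positions attend to each other in both directions.
    img_pos = input_token_ids == IMAGE_TOKEN_ID
    both_img = img_pos.view(1, 1, seq_len, 1) & img_pos.view(1, 1, 1, seq_len)
    return torch.where(both_img, torch.zeros_like(mask), mask)

tokens = torch.tensor([1, 99, 99, 7, 8])  # text, image, image, text, text
print(build_global_mask(tokens)[0, 0])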