
Commit fd36275

handle inputs from Siglip/Siglip2 non-automapped encoder layers (#41930)
* handle inputs from non-automapped encoder layers
* correct inheritance + protect executorch
* fixup
* fix tests
* missing docstring
* attn support
* fix initialization
* reorder/simplify
* flag test as broken
* minor changes
* modulaaar
1 parent 922e854 commit fd36275

File tree

4 files changed: +118 -99 lines changed

src/transformers/models/siglip/modeling_siglip.py
src/transformers/models/siglip2/modeling_siglip2.py
src/transformers/models/siglip2/modular_siglip2.py
utils/check_repo.py

src/transformers/models/siglip/modeling_siglip.py

Lines changed: 8 additions & 2 deletions
@@ -678,9 +678,14 @@ def forward(
         )
 
 
-class SiglipVisionTransformer(nn.Module):
+class SiglipVisionTransformer(SiglipPreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": SiglipEncoderLayer,
+        "attentions": SiglipAttention,
+    }
+
     def __init__(self, config: SiglipVisionConfig):
-        super().__init__()
+        super().__init__(config)
         self.config = config
         embed_dim = config.hidden_size
 
@@ -691,6 +696,7 @@ def __init__(self, config: SiglipVisionConfig):
         if self.use_head:
             self.head = SiglipMultiheadAttentionPoolingHead(config)
 
+    @check_model_inputs(tie_last_hidden_states=False)
     @auto_docstring
     def forward(
         self,
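The upshot of this hunk: `SiglipVisionTransformer` now subclasses `SiglipPreTrainedModel` and declares `_can_record_outputs`, so hidden states and attention weights are captured from `SiglipEncoderLayer` / `SiglipAttention` by the `@check_model_inputs` decorator rather than being threaded through the encoder's return values. A minimal sketch of exercising this through the public `SiglipVisionModel` follows; the tiny config values and the dummy input are assumptions for illustration, not part of the diff.

# Sketch only: tiny randomly initialized config, not a released checkpoint.
import torch
from transformers import SiglipVisionConfig, SiglipVisionModel

config = SiglipVisionConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    image_size=32,
    patch_size=16,
)
config._attn_implementation = "eager"  # eager attention so weights can be returned

model = SiglipVisionModel(config).eval()
pixel_values = torch.randn(1, 3, 32, 32)  # dummy image batch

with torch.no_grad():
    out = model(pixel_values=pixel_values, output_hidden_states=True, output_attentions=True)

# hidden_states and attentions are tuples of tensors recorded from the encoder layers
print(len(out.hidden_states), len(out.attentions))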

src/transformers/models/siglip2/modeling_siglip2.py

Lines changed: 99 additions & 93 deletions
@@ -349,99 +349,6 @@ def forward(
         return hidden_states
 
 
-class Siglip2Encoder(nn.Module):
-    """
-    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
-    [`Siglip2EncoderLayer`].
-
-    Args:
-        config: Siglip2Config
-    """
-
-    def __init__(self, config: Siglip2Config):
-        super().__init__()
-        self.config = config
-        self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.gradient_checkpointing = False
-
-    # Ignore copy
-    @auto_docstring
-    def forward(
-        self,
-        inputs_embeds,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> BaseModelOutput:
-        hidden_states = inputs_embeds
-        for encoder_layer in self.layers:
-            hidden_states = encoder_layer(
-                hidden_states,
-                attention_mask,
-                **kwargs,
-            )
-
-        return BaseModelOutput(last_hidden_state=hidden_states)
-
-
-class Siglip2VisionTransformer(nn.Module):
-    def __init__(self, config: Siglip2VisionConfig):
-        super().__init__()
-        self.config = config
-        embed_dim = config.hidden_size
-
-        self.embeddings = Siglip2VisionEmbeddings(config)
-        self.encoder = Siglip2Encoder(config)
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
-        if self.use_head:
-            self.head = Siglip2MultiheadAttentionPoolingHead(config)
-
-    @auto_docstring
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        attention_mask: torch.Tensor,
-        spatial_shapes: torch.LongTensor,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-    ) -> BaseModelOutputWithPooling:
-        r"""
-        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
-            Tensor containing the spatial dimensions (height, width) of the input images.
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        hidden_states = self.embeddings(pixel_values, spatial_shapes)
-
-        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
-            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-        else:
-            encoder_attention_mask = attention_mask
-
-        encoder_outputs: BaseModelOutput = self.encoder(
-            inputs_embeds=hidden_states,
-            attention_mask=encoder_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-        )
-
-        last_hidden_state = encoder_outputs.last_hidden_state
-        last_hidden_state = self.post_layernorm(last_hidden_state)
-
-        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
-
-        return BaseModelOutputWithPooling(
-            last_hidden_state=last_hidden_state,
-            pooler_output=pooler_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-
-
 def _trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -607,6 +514,105 @@ def _init_weights(self, module):
             module.weight.data.fill_(1.0)
 
 
+class Siglip2Encoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`Siglip2EncoderLayer`].
+
+    Args:
+        config: Siglip2Config
+    """
+
+    def __init__(self, config: Siglip2Config):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    # Ignore copy
+    @auto_docstring
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states,
+                attention_mask,
+                **kwargs,
+            )
+
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class Siglip2VisionTransformer(Siglip2PreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": Siglip2EncoderLayer,
+        "attentions": Siglip2Attention,
+    }
+
+    def __init__(self, config: Siglip2VisionConfig):
+        super().__init__(config)
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = Siglip2VisionEmbeddings(config)
+        self.encoder = Siglip2Encoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+        if self.use_head:
+            self.head = Siglip2MultiheadAttentionPoolingHead(config)
+
+    @check_model_inputs(tie_last_hidden_states=False)
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        attention_mask: torch.Tensor,
+        spatial_shapes: torch.LongTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            Tensor containing the spatial dimensions (height, width) of the input images.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        hidden_states = self.embeddings(pixel_values, spatial_shapes)
+
+        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+        else:
+            encoder_attention_mask = attention_mask
+
+        encoder_outputs: BaseModelOutput = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+
+        last_hidden_state = encoder_outputs.last_hidden_state
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooler_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
 class Siglip2TextEmbeddings(nn.Module):
     def __init__(self, config: Siglip2TextConfig):
         super().__init__()
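The Siglip2 side mirrors the Siglip change, except that `Siglip2Encoder` and `Siglip2VisionTransformer` are also moved below `_init_weights` so the transformer can subclass `Siglip2PreTrainedModel`, which must already be defined at that point. A hedged usage sketch through the public `Siglip2VisionModel` follows; the keyword names match the forward signature shown above (`pixel_values`, `pixel_attention_mask`, `spatial_shapes`), but the tiny config values, the flattened-patch input shape, and the dummy tensors are assumptions for illustration only.

# Sketch only: randomly initialized Siglip2 vision tower with dummy, pre-patchified inputs.
import torch
from transformers import Siglip2VisionConfig, Siglip2VisionModel

config = Siglip2VisionConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    patch_size=16,
    num_patches=16,  # assumed small value for the positional-embedding grid
)
config._attn_implementation = "eager"  # eager attention so weights can be returned

model = Siglip2VisionModel(config).eval()

max_patches = 16
pixel_values = torch.randn(1, max_patches, 3 * 16 * 16)              # (batch, num_patches, C * P * P), assumed layout
pixel_attention_mask = torch.ones(1, max_patches, dtype=torch.long)  # all patches are real
spatial_shapes = torch.tensor([[4, 4]], dtype=torch.long)            # (height, width) of each image in patches

with torch.no_grad():
    out = model(
        pixel_values=pixel_values,
        pixel_attention_mask=pixel_attention_mask,
        spatial_shapes=spatial_shapes,
        output_hidden_states=True,
        output_attentions=True,
    )

print(len(out.hidden_states), len(out.attentions))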

src/transformers/models/siglip2/modular_siglip2.py

Lines changed: 7 additions & 4 deletions
@@ -37,6 +37,7 @@
 
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...utils import auto_docstring, filter_out_non_signature_kwargs
+from ...utils.generic import check_model_inputs
 
 
 class Siglip2TextConfig(SiglipTextConfig):
@@ -230,6 +231,10 @@ def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTen
         return embeddings
 
 
+class Siglip2PreTrainedModel(SiglipPreTrainedModel):
+    pass
+
+
 class Siglip2VisionTransformer(SiglipVisionTransformer):
     def __init__(self, config: Siglip2VisionConfig):
         super().__init__(config)
@@ -280,10 +285,6 @@ def forward(
         )
 
 
-class Siglip2PreTrainedModel(SiglipPreTrainedModel):
-    pass
-
-
 class Siglip2TextModel(SiglipTextModel):
     pass
 
@@ -314,6 +315,8 @@ def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Ten
 
 class Siglip2VisionModel(SiglipVisionModel):
     # Update: add `spatial_shapes` and `pixel_attention_mask`
+    @check_model_inputs(tie_last_hidden_states=False)
+    @auto_docstring
     def forward(
         self,
         pixel_values: torch.FloatTensor,
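In the modular source, `Siglip2PreTrainedModel` moves above `Siglip2VisionTransformer` for the same reason: the generated `modeling_siglip2.py` shown earlier needs the pretrained base class defined before the transformer that inherits from it, and the modeling file is typically regenerated from this modular file with `utils/modular_model_converter.py`. A small sanity-check sketch of the resulting wiring:

# Sketch only: confirm the generated classes wire up as this diff intends.
from transformers.models.siglip2.modeling_siglip2 import (
    Siglip2PreTrainedModel,
    Siglip2VisionTransformer,
)

assert issubclass(Siglip2VisionTransformer, Siglip2PreTrainedModel)
# The recorded-output mapping added in this commit is a plain class attribute.
print(Siglip2VisionTransformer._can_record_outputs)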

utils/check_repo.py

Lines changed: 4 additions & 0 deletions
@@ -90,6 +90,8 @@
     "Kosmos2_5TextForCausalLM",
     "Kosmos2_5VisionModel",
     "SmolVLMVisionTransformer",
+    "SiglipVisionTransformer",
+    "Siglip2VisionTransformer",
     "AriaTextForCausalLM",
     "AriaTextModel",
     "Phi4MultimodalAudioModel",
@@ -358,7 +360,9 @@
     "SegGptForImageSegmentation",
     "SiglipVisionModel",
     "SiglipTextModel",
+    "SiglipVisionTransformer",
     "Siglip2VisionModel",
+    "Siglip2VisionTransformer",
     "Siglip2TextModel",
     "ChameleonVQVAE",  # no autoclass for VQ-VAE models
     "VitPoseForPoseEstimation",
