@@ -349,99 +349,6 @@ def forward(
         return hidden_states
 
 
-class Siglip2Encoder(nn.Module):
-    """
-    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
-    [`Siglip2EncoderLayer`].
-
-    Args:
-        config: Siglip2Config
-    """
-
-    def __init__(self, config: Siglip2Config):
-        super().__init__()
-        self.config = config
-        self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
-        self.gradient_checkpointing = False
-
-    # Ignore copy
-    @auto_docstring
-    def forward(
-        self,
-        inputs_embeds,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> BaseModelOutput:
-        hidden_states = inputs_embeds
-        for encoder_layer in self.layers:
-            hidden_states = encoder_layer(
-                hidden_states,
-                attention_mask,
-                **kwargs,
-            )
-
-        return BaseModelOutput(last_hidden_state=hidden_states)
-
-
-class Siglip2VisionTransformer(nn.Module):
-    def __init__(self, config: Siglip2VisionConfig):
-        super().__init__()
-        self.config = config
-        embed_dim = config.hidden_size
-
-        self.embeddings = Siglip2VisionEmbeddings(config)
-        self.encoder = Siglip2Encoder(config)
-        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
-        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
-        if self.use_head:
-            self.head = Siglip2MultiheadAttentionPoolingHead(config)
-
-    @auto_docstring
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        attention_mask: torch.Tensor,
-        spatial_shapes: torch.LongTensor,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-    ) -> BaseModelOutputWithPooling:
-        r"""
-        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
-            Tensor containing the spatial dimensions (height, width) of the input images.
-        """
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        hidden_states = self.embeddings(pixel_values, spatial_shapes)
-
-        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
-            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
-            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
-        else:
-            encoder_attention_mask = attention_mask
-
-        encoder_outputs: BaseModelOutput = self.encoder(
-            inputs_embeds=hidden_states,
-            attention_mask=encoder_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-        )
-
-        last_hidden_state = encoder_outputs.last_hidden_state
-        last_hidden_state = self.post_layernorm(last_hidden_state)
-
-        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
-
-        return BaseModelOutputWithPooling(
-            last_hidden_state=last_hidden_state,
-            pooler_output=pooler_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
-
-
 def _trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -607,6 +514,105 @@ def _init_weights(self, module):
             module.weight.data.fill_(1.0)
 
 
+class Siglip2Encoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`Siglip2EncoderLayer`].
+
+    Args:
+        config: Siglip2Config
+    """
+
+    def __init__(self, config: Siglip2Config):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
+
+    # Ignore copy
+    @auto_docstring
+    def forward(
+        self,
+        inputs_embeds,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutput:
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states,
+                attention_mask,
+                **kwargs,
+            )
+
+        return BaseModelOutput(last_hidden_state=hidden_states)
+
+
+class Siglip2VisionTransformer(Siglip2PreTrainedModel):
+    _can_record_outputs = {
+        "hidden_states": Siglip2EncoderLayer,
+        "attentions": Siglip2Attention,
+    }
+
+    def __init__(self, config: Siglip2VisionConfig):
+        super().__init__(config)
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = Siglip2VisionEmbeddings(config)
+        self.encoder = Siglip2Encoder(config)
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+        self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
+        if self.use_head:
+            self.head = Siglip2MultiheadAttentionPoolingHead(config)
+
+    @check_model_inputs(tie_last_hidden_states=False)
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        attention_mask: torch.Tensor,
+        spatial_shapes: torch.LongTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> BaseModelOutputWithPooling:
+        r"""
+        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
+            Tensor containing the spatial dimensions (height, width) of the input images.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+
+        hidden_states = self.embeddings(pixel_values, spatial_shapes)
+
+        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
+            # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+        else:
+            encoder_attention_mask = attention_mask
+
+        encoder_outputs: BaseModelOutput = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+        )
+
+        last_hidden_state = encoder_outputs.last_hidden_state
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        pooler_output = self.head(last_hidden_state, attention_mask) if self.use_head else None
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=last_hidden_state,
+            pooler_output=pooler_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
 class Siglip2TextEmbeddings(nn.Module):
     def __init__(self, config: Siglip2TextConfig):
         super().__init__()
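For reference, a minimal usage sketch of the relocated vision tower (not part of the commit). It assumes `Siglip2VisionTransformer` stays importable from `transformers.models.siglip2.modeling_siglip2` and that the default `Siglip2VisionConfig` exposes `num_channels`, `patch_size`, and `hidden_size`; the batch size, patch count, and 16x16 grid below are illustrative only.

import torch

from transformers import Siglip2VisionConfig
from transformers.models.siglip2.modeling_siglip2 import Siglip2VisionTransformer

config = Siglip2VisionConfig()
model = Siglip2VisionTransformer(config).eval()

# One image packed as 256 flattened patches on a 16x16 grid, every patch valid.
batch_size, num_patches = 1, 256
patch_dim = config.num_channels * config.patch_size * config.patch_size
pixel_values = torch.randn(batch_size, num_patches, patch_dim)           # (batch, num_patches, C * P * P)
attention_mask = torch.ones(batch_size, num_patches, dtype=torch.long)   # 1 = real patch, 0 = padding
spatial_shapes = torch.tensor([[16, 16]], dtype=torch.long)              # (height, width) in patches

with torch.no_grad():
    outputs = model(
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        spatial_shapes=spatial_shapes,
    )

print(outputs.last_hidden_state.shape)  # (1, 256, config.hidden_size)
print(outputs.pooler_output.shape)      # (1, config.hidden_size) when the pooling head is enabled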