@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
845845 """
846846 import numpy as np
847847
848- def _n2p (w , t = True ):
848+ def _n2p (w , t = True , idx = None ):
849+ if idx is not None :
850+ w = w [idx ]
849851 if w .ndim == 4 and w .shape [0 ] == w .shape [1 ] == w .shape [2 ] == 1 :
850852 w = w .flatten ()
851853 if t :
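
Note on the new `idx` argument: the PaliGemma-style npz checkpoints store all transformer blocks stacked along a leading axis under a single `encoderblock/...` key, so `_n2p` can now slice out one block before its usual conversion. A minimal sketch with toy numpy arrays (shapes and values are made up for illustration only):

import numpy as np

# Toy stand-in for a stacked checkpoint entry: 27 blocks, each a (1152,) LayerNorm scale.
stacked_scale = np.stack([np.full(1152, i, dtype=np.float32) for i in range(27)])

def n2p_sketch(w, idx=None):
    # With idx set, pick out a single block's slice before any further conversion.
    if idx is not None:
        w = w[idx]
    return np.asarray(w)

print(n2p_sketch(stacked_scale, idx=3).shape)  # (1152,) -- weights for block 3 only
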
@@ -955,21 +957,28 @@ def _n2p(w, t=True):
 
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
 
 
 def _convert_openai_clip(
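
The branch added to the block loop keys off the checkpoint layout: if a single stacked `Transformer/encoderblock/...` entry exists, the same key is reused for every block and `idx` selects the slice; otherwise the older per-block `Transformer/encoderblock_{i}/...` keys are used and `idx` stays `None`. A small sketch of that dispatch, with plain dicts standing in for the loaded npz (key names mirror the diff; the dict values are placeholders):

prefix = ''
depth = 3

stacked_ckpt = {f'{prefix}Transformer/encoderblock/LayerNorm_0/scale': 'stacked array'}
per_block_ckpt = {
    f'{prefix}Transformer/encoderblock_{i}/LayerNorm_0/scale': f'block {i} array'
    for i in range(depth)
}

for name, w in (('stacked', stacked_ckpt), ('per-block', per_block_ckpt)):
    for i in range(depth):
        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
            block_prefix = f'{prefix}Transformer/encoderblock/'  # one shared key
            idx = i                                              # slice by block index
        else:
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'  # distinct key per block
            idx = None
        print(name, block_prefix, idx)
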
@@ -1769,6 +1778,44 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
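
For context, the new entries above become selectable pretrained tags of the form `model_name.tag`. A hedged usage sketch (assumes a timm build that includes this change; set `pretrained=True` only if the corresponding Hugging Face Hub repos are accessible):

import timm

# List the SO400M GAP variants with pretrained tags registered by this change.
print(timm.list_models('vit_so400m_patch14_siglip_gap_*', pretrained=True))

# Build the 448px PaliGemma-mix image tower as a headless feature extractor
# (num_classes=0 matches the pretrained cfg above).
model = timm.create_model(
    'vit_so400m_patch14_siglip_gap_448.pali_mix',
    pretrained=False,  # True would fetch paligemma-3b-mix-448.npz via custom_load='hf'
    num_classes=0,
)
model.eval()
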
@@ -2756,15 +2803,48 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 
 @register_model
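
Finally, a short sketch of what the new factory functions produce, assuming the args in this diff (patch 14, 1152-dim embeddings, no class token, global average pooling); it runs without pretrained weights so nothing is downloaded:

import torch
import timm

model = timm.create_model('vit_so400m_patch14_siglip_gap_224', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

# With class_token=False and global_pool='avg', the output is the mean over the
# 16x16 = 256 patch tokens, so the feature dim equals embed_dim.
print(feats.shape)  # expected: torch.Size([1, 1152])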