@@ -44,7 +44,7 @@
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
     trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    SwiGLU, get_act_layer, get_norm_layer, LayerType
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -65,6 +65,7 @@ def __init__(
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             norm_layer: nn.Module = nn.LayerNorm,
@@ -80,7 +81,7 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -130,6 +131,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -145,6 +147,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -157,6 +160,7 @@ def __init__(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
@@ -176,6 +180,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -192,6 +197,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
+            proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             norm_layer=norm_layer,
@@ -203,6 +209,7 @@ def __init__(
             in_features=dim,
             hidden_features=int(dim * mlp_ratio),
             act_layer=act_layer,
+            bias=proj_bias,
             drop=proj_drop,
         )
         self.norm2 = norm_layer(dim)
@@ -236,6 +243,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
             init_values: Optional[float] = None,
@@ -266,11 +274,11 @@ def __init__(
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
-        self.attn_out_proj = nn.Linear(dim, dim)
+        self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias)

         self.mlp_drop = nn.Dropout(proj_drop)
         self.mlp_act = act_layer()
-        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
+        self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias)

         self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -330,6 +338,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -350,6 +359,7 @@ def __init__(
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
                     norm_layer=norm_layer,
@@ -363,6 +373,7 @@ def __init__(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    bias=proj_bias,
                     drop=proj_drop,
                 )),
                 ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
@@ -433,6 +444,7 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = True,
             qk_norm: bool = False,
+            proj_bias: bool = True,
             init_values: Optional[float] = None,
             class_token: bool = True,
             pos_embed: str = 'learn',
@@ -452,6 +464,7 @@ def __init__(
             weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
             fix_init: bool = False,
             embed_layer: Callable = PatchEmbed,
+            embed_norm_layer: Optional[LayerType] = None,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
@@ -483,6 +496,7 @@ def __init__(
             weight_init: Weight initialization scheme.
             fix_init: Apply weight initialization fix (scaling w/ layer index).
             embed_layer: Patch embedding layer.
+            embed_norm_layer: Normalization layer to use / override in patch embed module.
             norm_layer: Normalization layer.
             act_layer: MLP activation layer.
             block_fn: Transformer block layer.
@@ -493,6 +507,7 @@ def __init__(
         assert pos_embed in ('', 'none', 'learn')
         use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+        embed_norm_layer = get_norm_layer(embed_norm_layer)
         act_layer = get_act_layer(act_layer) or nn.GELU

         self.num_classes = num_classes
@@ -510,6 +525,8 @@ def __init__(
         if dynamic_img_size:
             # flatten deferred until after pos embed
             embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
+        if embed_norm_layer is not None:
+            embed_args['norm_layer'] = embed_norm_layer
         self.patch_embed = embed_layer(
             img_size=img_size,
             patch_size=patch_size,
@@ -539,14 +556,15 @@ def __init__(
             self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth, device='cpu')]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             block_fn(
                 dim=embed_dim,
                 num_heads=num_heads,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 qk_norm=qk_norm,
+                proj_bias=proj_bias,
                 init_values=init_values,
                 proj_drop=proj_drop_rate,
                 attn_drop=attn_drop_rate,
@@ -1128,6 +1146,31 @@ def _convert_dinov2(
     return out_dict


+def _convert_aimv2(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
+    #import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        k = k.replace('norm_1', 'norm1')
+        k = k.replace('norm_2', 'norm2')
+        k = k.replace('preprocessor.patchifier.', 'patch_embed.')
+        k = k.replace('preprocessor.pos_embed', 'pos_embed')
+        k = k.replace('trunk.', '')
+        k = k.replace('mlp.fc1', 'mlp.fc1_g')
+        k = k.replace('mlp.fc3', 'mlp.fc1_x')
+        k = k.replace('post_trunk_norm.', 'norm.')
+        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w12", "fc1")] = v
+        #     continue
+        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
+        #     out_dict[k.replace("w3", "fc2")] = v
+        #     continue
+        out_dict[k] = v
+    return out_dict
+
+
 def checkpoint_filter_fn(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
@@ -1159,6 +1202,8 @@ def checkpoint_filter_fn(
             # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
             out_dict['head.weight'] = state_dict['visual.head.proj.weight']
             out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'preprocessor.patchifier.proj.weight' in state_dict:
+        state_dict = _convert_aimv2(state_dict, model)

     if prefix:
         # filter on & remove prefix string from keys
@@ -2119,6 +2164,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
     ),

+    'vit_large_patch14_aimv2_224': _cfg(
+        hf_hub_id='apple/aimv2-large-patch14-224',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 224, 224), crop_pct=1.0,
+        num_classes=0),
+
     'test_vit.r160_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
@@ -3390,6 +3441,21 @@ def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTran
     return model


+@register_model
+def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/14) for AIMv2 weights, w/ SwiGLU MLP, RMSNorm, and no qkv/proj biases.
+    """
+    rms_norm = partial(RmsNorm, eps=1e-5)
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
+        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
+        qkv_bias=False, proj_bias=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test
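Not part of the commit, but a minimal usage sketch of the new entrypoint for review context, assuming only standard timm APIs (create_model, forward_features, forward_head). The shapes in the comments follow from the config above: a 224x224 input at patch size 14 gives 16x16 = 256 tokens of width 1024, with no class token.

import torch
import timm  # this branch

# num_classes=0 keeps the classifier head as nn.Identity; pretrained=True would
# resolve the apple/aimv2-large-patch14-224 hub id registered above and push the
# checkpoint through checkpoint_filter_fn -> _convert_aimv2.
model = timm.create_model('vit_large_patch14_aimv2_224', pretrained=False, num_classes=0)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    tokens = model.forward_features(x)   # (1, 256, 1024) patch tokens, final RMSNorm applied
    pooled = model.forward_head(tokens)  # (1, 1024) after 'avg' global pooling
print(tokens.shape, pooled.shape)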