@@ -1150,25 +1150,25 @@ def _convert_aimv2(
         state_dict: Dict[str, torch.Tensor],
         model: VisionTransformer,
 ) -> Dict[str, torch.Tensor]:
-    #import re
     out_dict = {}
-
     for k, v in state_dict.items():
         k = k.replace('norm_1', 'norm1')
         k = k.replace('norm_2', 'norm2')
         k = k.replace('preprocessor.patchifier.', 'patch_embed.')
         k = k.replace('preprocessor.pos_embed', 'pos_embed')
         k = k.replace('trunk.', '')
-        k = k.replace('mlp.fc1', 'mlp.fc1_g')
-        k = k.replace('mlp.fc3', 'mlp.fc1_x')
         k = k.replace('post_trunk_norm.', 'norm.')
-        # if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w12", "fc1")] = v
-        #     continue
-        # elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
-        #     out_dict[k.replace("w3", "fc2")] = v
-        #     continue
+
+        if 'mlp.fc1' in k:
+            if k in out_dict:
+                v = torch.cat([v, out_dict[k]], dim=0)
+        elif 'mlp.fc3' in k:
+            k = k.replace('mlp.fc3', 'mlp.fc1')
+            if k in out_dict:
+                v = torch.cat([out_dict[k], v], dim=0)
+
         out_dict[k] = v
+
     return out_dict
 
 def checkpoint_filter_fn(
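The new branch replaces the earlier `fc1_g`/`fc1_x` renaming: instead of keeping separate gate and value projections, the converter concatenates the checkpoint's `mlp.fc1` (gate) and `mlp.fc3` (value) tensors along the output dimension, always in `[fc1; fc3]` order regardless of which key is encountered first, so they load into a single packed projection. Below is a minimal standalone sketch of why that packing is equivalent; the tensors and variable names are random illustrative stand-ins, not real checkpoint weights or part of the commit.

```python
import torch
import torch.nn.functional as F

# Stand-in shapes for AIMv2-L: embed_dim 1024, per-branch hidden width 2816.
embed_dim, hidden = 1024, 2816
fc1_w = torch.randn(hidden, embed_dim)  # gate projection ('mlp.fc1' in the checkpoint)
fc3_w = torch.randn(hidden, embed_dim)  # value projection ('mlp.fc3' in the checkpoint)

# The converter's torch.cat([...], dim=0) produces this packed layout.
packed_w = torch.cat([fc1_w, fc3_w], dim=0)  # (2 * hidden, embed_dim)

# One matmul through the packed weight, then split: the gate half comes
# first, matching the [fc1; fc3] concatenation order above.
x = torch.randn(2, embed_dim)
gate, value = F.linear(x, packed_w).chunk(2, dim=-1)
out = F.silu(gate) * value

# Same result as running the two original projections separately.
ref = F.silu(F.linear(x, fc1_w)) * F.linear(x, fc3_w)
assert torch.allclose(out, ref, atol=1e-3)
```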
@@ -3448,8 +3448,8 @@ def vit_large_patch14_aimv2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     rms_norm = partial(RmsNorm, eps=1e-5)
     model_args = dict(
         patch_size=14, embed_dim=1024, depth=24, num_heads=16, class_token=False, fc_norm=False,
-        mlp_ratio=2.75, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLU,
-        qkv_bias=False, proj_bias=False,
+        mlp_ratio=5.5, global_pool='avg', norm_layer=rms_norm, embed_norm_layer=rms_norm, mlp_layer=SwiGLUPacked,
+        qkv_bias=False, proj_bias=False, act_layer='silu'
     )
     model = _create_vision_transformer(
         'vit_large_patch14_aimv2_224', pretrained=pretrained, **dict(model_args, **kwargs))
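The `mlp_ratio` doubling from 2.75 to 5.5 follows from the packing: with `SwiGLUPacked`, `fc1` must emit both the gate and value halves, so its output width is twice the per-branch hidden width that the unpacked `SwiGLU` used. A quick shape check, derived from the args above (not part of the commit):

```python
embed_dim = 1024
per_branch = int(embed_dim * 2.75)  # 2816: hidden width of each branch with separate fc1/fc3
packed_out = int(embed_dim * 5.5)   # 5632: fc1 output width with the packed layout
assert packed_out == 2 * per_branch
```

Accordingly, the fused `mlp.fc1.weight` that `_convert_aimv2` produces for this model has shape `(5632, 1024)`.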