 import torch
 import torch.nn as nn

-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
-    ClassifierHead
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import (
+    DropPath,
+    trunc_normal_,
+    create_conv2d,
+    ConvNormAct,
+    SqueezeExcite,
+    use_fused_attn,
+    ClassifierHead,
+    LayerNorm2d,
+)
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
@@ -427,7 +435,8 @@ def convolutional_stem(
         in_chs: int,
         out_chs: int,
         act_layer: Type[nn.Module] = nn.GELU,
-        inference_mode: bool = False
+        inference_mode: bool = False,
+        use_scale_branch: bool = True,
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.

@@ -447,6 +456,7 @@ def convolutional_stem(
             stride=2,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -456,6 +466,7 @@ def convolutional_stem(
             group_size=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -464,6 +475,7 @@ def convolutional_stem(
             stride=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
     )

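Context on the new `use_scale_branch` flag: a MobileOne-style block trains with a kxk conv branch plus an optional parallel 1x1 "scale" branch (each batch-normalized in the real implementation), and the branches fold into a single conv for inference; the mci3/mci4 stems added below disable the 1x1 branch. A minimal sketch of the folding idea, with batchnorm omitted and all names hypothetical rather than the actual `MobileOneBlock` internals:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReparamBlockSketch(nn.Module):
    """Parallel kxk + optional 1x1 branches that fold into one conv."""

    def __init__(self, chs: int, kernel_size: int = 3, use_scale_branch: bool = True):
        super().__init__()
        self.pad = kernel_size // 2
        self.conv_kxk = nn.Conv2d(chs, chs, kernel_size, padding=self.pad)
        # use_scale_branch=False simply drops this branch, as in the new stems.
        self.conv_scale = nn.Conv2d(chs, chs, 1) if use_scale_branch else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv_kxk(x)
        if self.conv_scale is not None:
            out = out + self.conv_scale(x)
        return out

    @torch.no_grad()
    def fuse(self) -> nn.Conv2d:
        # Zero-pad the 1x1 kernel to kxk, then sum the branches into one conv.
        w = self.conv_kxk.weight.clone()
        b = self.conv_kxk.bias.clone()
        if self.conv_scale is not None:
            w += F.pad(self.conv_scale.weight, [self.pad] * 4)
            b += self.conv_scale.bias
        fused = nn.Conv2d(w.shape[1], w.shape[0], w.shape[2], padding=self.pad)
        fused.weight.copy_(w)
        fused.bias.copy_(b)
        return fused

# The fused conv reproduces the multi-branch forward:
blk = ReparamBlockSketch(8).eval()
x = torch.randn(1, 8, 16, 16)
assert torch.allclose(blk(x), blk.fuse()(x), atol=1e-5)
```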
@@ -1118,6 +1130,7 @@ def __init__(
             drop_path_rate: float = 0.0,
             layer_scale_init_value: float = 1e-5,
             lkc_use_act: bool = False,
+            stem_use_scale_branch: bool = True,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
@@ -1137,6 +1150,7 @@ def __init__(
             embed_dims[0],
             act_layer,
             inference_mode,
+            use_scale_branch=stem_use_scale_branch,
         )

         # Build the main stages of the network architecture
@@ -1412,6 +1426,35 @@ def _cfg(url="", **kwargs):
         num_classes=512,  # CLIP proj dim
         mean=(0., 0., 0.), std=(1., 1., 1.)
     ),
+
+    "fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr'
+    ),
+    "fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr'
+    ),
+    "fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-amlr'
+    ),
+    "fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-amlr'
+    ),
 })

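Once these tags are published under the `timm/` Hub org, they resolve through the standard factory path; a usage sketch (assumes the Hub weights exist; output dim follows the `num_classes` comments above):

```python
import timm
import torch

# Pulls the MobileCLIP2 DFNDR-2B image-tower weights; the head projects to
# the CLIP embedding dim (512 for mci0/mci2, 768 for mci3/mci4).
model = timm.create_model('fastvit_mci3.apple_mclip2_dfndr2b', pretrained=True)
model.eval()

cfg = timm.data.resolve_model_data_config(model)
with torch.no_grad():
    out = model(torch.randn(1, *cfg['input_size']))
print(out.shape)  # torch.Size([1, 768])
```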
@@ -1420,6 +1463,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed

+    if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
+        return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}
+
     state_dict = state_dict.get('state_dict', state_dict)
     if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
         # remap MobileCLIP checkpoints
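The new early return handles checkpoints saved from a (presumably DDP-wrapped) CLIP-style model: it keeps only image-trunk tensors and strips their prefix. A self-contained illustration with toy tensors (the helper name here is ours, not part of the module):

```python
import torch

def strip_trunk_prefix(state_dict):
    # Same transform as the new checkpoint_filter_fn branch: drop every key
    # outside the image trunk and remove the wrapper prefix from the rest.
    return {
        k.replace('module.visual.trunk.', ''): v
        for k, v in state_dict.items()
        if k.startswith('module.visual.trunk')
    }

ckpt = {
    'module.visual.trunk.stem.0.conv_kxk.0.conv.weight': torch.zeros(1),
    'module.text.transformer.resblocks.0.attn.in_proj_weight': torch.zeros(1),
}
print(list(strip_trunk_prefix(ckpt)))
# ['stem.0.conv_kxk.0.conv.weight']  -- text-tower key discarded
```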
@@ -1632,3 +1678,54 @@ def fastvit_mci2(pretrained=False, **kwargs):
         lkc_use_act=True,
     )
     return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci3(pretrained=False, **kwargs):
+    """Instantiate L model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 2),
+        embed_dims=(96, 192, 384, 768, 1536),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7))
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+    model = _create_fastvit('fastvit_mci3', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def fastvit_mci4(pretrained=False, **kwargs):
+    """Instantiate XL model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 4),
+        embed_dims=(128, 256, 512, 1024, 2048),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7))
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+
+    model = _create_fastvit('fastvit_mci4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
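With the `@register_model` registrations above, the new variants participate in the usual timm factory and reparameterization flow; a quick smoke-test sketch (random init here; `pretrained=True` would pull the `*.apple_mclip2_dfndr2b` weights, and we assume `timm.utils.reparameterize_model` for branch fusion):

```python
import timm
import torch
from timm.utils import reparameterize_model

# Headless trunk at random init: features only, no CLIP projection head.
model = timm.create_model('fastvit_mci4', num_classes=0)
model.eval()

# Fold train-time kxk/scale/identity branches into single convs for inference.
model = reparameterize_model(model)

with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 256))
print(feats.shape)
```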