1515from timm .layers import trunc_normal_ , DropPath , LayerNorm , LayerScale , ClNormMlpClassifierHead , get_act_layer
1616from ._builder import build_model_with_cfg
1717from ._manipulate import checkpoint_seq
18- from ._registry import register_model
18+ from ._registry import register_model , generate_default_cfgs
1919
2020
2121class Stem (nn .Module ):
@@ -435,6 +435,8 @@ def forward(self, x):
435435def checkpoint_filter_fn (state_dict , model ):
436436 if 'model' in state_dict :
437437 state_dict = state_dict ['model' ]
438+ if 'stem.conv1.weight' in state_dict :
439+ return state_dict
438440
439441 import re
440442 out_dict = {}
@@ -458,30 +460,52 @@ def checkpoint_filter_fn(state_dict, model):
458460def _cfg (url = '' , ** kwargs ):
459461 return {
460462 'url' : url ,
461- 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size ' : (7 , 7 ),
462- 'crop_pct' : 1.0 , 'interpolation' : 'bicubic' ,
463+ 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'test_input_size ' : (3 , 288 , 288 ),
464+ 'pool_size' : ( 7 , 7 ), ' crop_pct' : 1.0 , 'interpolation' : 'bicubic' ,
463465 'mean' : IMAGENET_DEFAULT_MEAN , 'std' : IMAGENET_DEFAULT_STD ,
464466 'first_conv' : 'stem.conv1' , 'classifier' : 'head.fc' ,
465467 ** kwargs
466468 }
467469
468470
469- default_cfgs = {
470- 'mambaout_femto' : _cfg (
471- url = 'https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth' ),
472- 'mambaout_kobe' : _cfg (
473- url = 'https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_kobe.pth' ),
474- 'mambaout_tiny' : _cfg (
475- url = 'https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth' ),
476- 'mambaout_small' : _cfg (
477- url = 'https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth' ),
478- 'mambaout_base' : _cfg (
479- url = 'https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth' ),
480- 'mambaout_small_rw' : _cfg (),
481- 'mambaout_base_slim_rw' : _cfg (),
482- 'mambaout_base_plus_rw' : _cfg (),
483- 'test_mambaout' : _cfg (input_size = (3 , 160 , 160 ), pool_size = (5 , 5 )),
484- }
471+ default_cfgs = generate_default_cfgs ({
472+ # original weights
473+ 'mambaout_femto.in1k' : _cfg (
474+ hf_hub_id = 'timm/' ),
475+ 'mambaout_kobe.in1k' : _cfg (
476+ hf_hub_id = 'timm/' ),
477+ 'mambaout_tiny.in1k' : _cfg (
478+ hf_hub_id = 'timm/' ),
479+ 'mambaout_small.in1k' : _cfg (
480+ hf_hub_id = 'timm/' ),
481+ 'mambaout_base.in1k' : _cfg (
482+ hf_hub_id = 'timm/' ),
483+
484+ # timm experiments below
485+ 'mambaout_small_rw.sw_e450_in1k' : _cfg (
486+ hf_hub_id = 'timm/' ,
487+ ),
488+ 'mambaout_base_short_rw.sw_e500_in1k' : _cfg (
489+ hf_hub_id = 'timm/' ,
490+ crop_pct = 0.95 , test_crop_pct = 1.0 ,
491+ ),
492+ 'mambaout_base_tall_rw.sw_e500_in1k' : _cfg (
493+ hf_hub_id = 'timm/' ,
494+ crop_pct = 0.95 , test_crop_pct = 1.0 ,
495+ ),
496+ 'mambaout_base_wide_rw.sw_e500_in1k' : _cfg (
497+ hf_hub_id = 'timm/' ,
498+ crop_pct = 0.95 , test_crop_pct = 1.0 ,
499+ ),
500+ 'mambaout_base_plus_rw.sw_e150_in12k_ft_in1k' : _cfg (
501+ hf_hub_id = 'timm/' ,
502+ ),
503+ 'mambaout_base_plus_rw.sw_e150_in12k' : _cfg (
504+ hf_hub_id = 'timm/' ,
505+ num_classes = 11821 ,
506+ ),
507+ 'test_mambaout' : _cfg (input_size = (3 , 160 , 160 ), test_input_size = (3 , 192 , 192 ), pool_size = (5 , 5 )),
508+ })
485509
486510
487511def _create_mambaout (variant , pretrained = False , ** kwargs ):
@@ -538,9 +562,24 @@ def mambaout_small_rw(pretrained=False, **kwargs):
538562
539563
540564@register_model
541- def mambaout_base_slim_rw (pretrained = False , ** kwargs ):
565+ def mambaout_base_short_rw (pretrained = False , ** kwargs ):
542566 model_args = dict (
543- depths = (3 , 4 , 27 , 3 ),
567+ depths = (3 , 3 , 25 , 3 ),
568+ dims = (128 , 256 , 512 , 768 ),
569+ expansion_ratio = 3.0 ,
570+ conv_ratio = 1.25 ,
571+ stem_mid_norm = False ,
572+ downsample = 'conv_nf' ,
573+ ls_init_value = 1e-6 ,
574+ head_fn = 'norm_mlp' ,
575+ )
576+ return _create_mambaout ('mambaout_base_short_rw' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
577+
578+
579+ @register_model
580+ def mambaout_base_tall_rw (pretrained = False , ** kwargs ):
581+ model_args = dict (
582+ depths = (3 , 4 , 30 , 3 ),
544583 dims = (128 , 256 , 512 , 768 ),
545584 expansion_ratio = 2.5 ,
546585 conv_ratio = 1.25 ,
@@ -549,11 +588,11 @@ def mambaout_base_slim_rw(pretrained=False, **kwargs):
549588 ls_init_value = 1e-6 ,
550589 head_fn = 'norm_mlp' ,
551590 )
552- return _create_mambaout ('mambaout_base_slim_rw ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
591+ return _create_mambaout ('mambaout_base_tall_rw ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
553592
554593
555594@register_model
556- def mambaout_base_plus_rw (pretrained = False , ** kwargs ):
595+ def mambaout_base_wide_rw (pretrained = False , ** kwargs ):
557596 model_args = dict (
558597 depths = (3 , 4 , 27 , 3 ),
559598 dims = (128 , 256 , 512 , 768 ),
@@ -565,6 +604,22 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
565604 act_layer = 'silu' ,
566605 head_fn = 'norm_mlp' ,
567606 )
607+ return _create_mambaout ('mambaout_base_wide_rw' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
608+
609+
610+ @register_model
611+ def mambaout_base_plus_rw (pretrained = False , ** kwargs ):
612+ model_args = dict (
613+ depths = (3 , 4 , 30 , 3 ),
614+ dims = (128 , 256 , 512 , 768 ),
615+ expansion_ratio = 3.0 ,
616+ conv_ratio = 1.5 ,
617+ stem_mid_norm = False ,
618+ downsample = 'conv_nf' ,
619+ ls_init_value = 1e-6 ,
620+ act_layer = 'silu' ,
621+ head_fn = 'norm_mlp' ,
622+ )
568623 return _create_mambaout ('mambaout_base_plus_rw' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
569624
570625
0 commit comments