Skip to content

Commit 82ae247

Browse files
committed
MambaOut weights on hub, configs finalized
1 parent 7efb60c commit 82ae247

File tree

1 file changed

+78
-23
lines changed

1 file changed

+78
-23
lines changed

timm/models/mambaout.py

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
1616
from ._builder import build_model_with_cfg
1717
from ._manipulate import checkpoint_seq
18-
from ._registry import register_model
18+
from ._registry import register_model, generate_default_cfgs
1919

2020

2121
class Stem(nn.Module):
@@ -435,6 +435,8 @@ def forward(self, x):
435435
def checkpoint_filter_fn(state_dict, model):
436436
if 'model' in state_dict:
437437
state_dict = state_dict['model']
438+
if 'stem.conv1.weight' in state_dict:
439+
return state_dict
438440

439441
import re
440442
out_dict = {}
@@ -458,30 +460,52 @@ def checkpoint_filter_fn(state_dict, model):
458460
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config dict for MambaOut models.

    Args:
        url: Direct download URL for the checkpoint (may be empty when the
            weights are hosted on the HF hub instead).
        **kwargs: Per-model overrides merged on top of the defaults
            (e.g. ``num_classes``, ``crop_pct``, ``input_size``).

    Returns:
        A config dict in timm's standard pretrained-cfg format.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        test_input_size=(3, 288, 288),
        pool_size=(7, 7),
        crop_pct=1.0,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1',
        classifier='head.fc',
    )
    # Caller-supplied keys win over the defaults, matching the original
    # `{..., **kwargs}` merge semantics.
    cfg.update(kwargs)
    return cfg
467469

468470

469-
default_cfgs = {
470-
'mambaout_femto': _cfg(
471-
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth'),
472-
'mambaout_kobe': _cfg(
473-
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_kobe.pth'),
474-
'mambaout_tiny': _cfg(
475-
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth'),
476-
'mambaout_small': _cfg(
477-
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth'),
478-
'mambaout_base': _cfg(
479-
url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'),
480-
'mambaout_small_rw': _cfg(),
481-
'mambaout_base_slim_rw': _cfg(),
482-
'mambaout_base_plus_rw': _cfg(),
483-
'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
484-
}
471+
# Pretrained-weight configs, keyed as '<model_name>.<weight_tag>'.
# `generate_default_cfgs` expands these into timm's registry format;
# `hf_hub_id='timm/'` means the checkpoint lives on the HF hub under the
# timm org, with the repo name derived from the cfg key.
default_cfgs = generate_default_cfgs({
    # original weights
    'mambaout_femto.in1k': _cfg(
        hf_hub_id='timm/'),
    'mambaout_kobe.in1k': _cfg(
        hf_hub_id='timm/'),
    'mambaout_tiny.in1k': _cfg(
        hf_hub_id='timm/'),
    'mambaout_small.in1k': _cfg(
        hf_hub_id='timm/'),
    'mambaout_base.in1k': _cfg(
        hf_hub_id='timm/'),

    # timm experiments below
    'mambaout_small_rw.sw_e450_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'mambaout_base_short_rw.sw_e500_in1k': _cfg(
        hf_hub_id='timm/',
        # trained/evaluated at a smaller crop; full crop only at test res
        crop_pct=0.95, test_crop_pct=1.0,
    ),
    'mambaout_base_tall_rw.sw_e500_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_crop_pct=1.0,
    ),
    'mambaout_base_wide_rw.sw_e500_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_crop_pct=1.0,
    ),
    'mambaout_base_plus_rw.sw_e150_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'mambaout_base_plus_rw.sw_e150_in12k': _cfg(
        hf_hub_id='timm/',
        # ImageNet-12k pretrain head, not the default 1000 classes
        num_classes=11821,
    ),
    # tiny config used by unit tests only (no weights)
    'test_mambaout': _cfg(input_size=(3, 160, 160), test_input_size=(3, 192, 192), pool_size=(5, 5)),
})
485509

486510

487511
def _create_mambaout(variant, pretrained=False, **kwargs):
@@ -538,9 +562,24 @@ def mambaout_small_rw(pretrained=False, **kwargs):
538562

539563

540564
@register_model
def mambaout_base_short_rw(pretrained=False, **kwargs):
    """MambaOut-Base "short" — a timm experimental variant with a shallower
    third stage (depths (3, 3, 25, 3)) and 3.0x expansion ratio.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides forwarded to the model constructor.
    """
    defaults = {
        'depths': (3, 3, 25, 3),
        'dims': (128, 256, 512, 768),
        'expansion_ratio': 3.0,
        'conv_ratio': 1.25,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'head_fn': 'norm_mlp',
    }
    # kwargs take precedence over the variant defaults, as in
    # dict(model_args, **kwargs).
    return _create_mambaout('mambaout_base_short_rw', pretrained=pretrained, **{**defaults, **kwargs})
577+
578+
579+
@register_model
580+
def mambaout_base_tall_rw(pretrained=False, **kwargs):
581+
model_args = dict(
582+
depths=(3, 4, 30, 3),
544583
dims=(128, 256, 512, 768),
545584
expansion_ratio=2.5,
546585
conv_ratio=1.25,
@@ -549,11 +588,11 @@ def mambaout_base_slim_rw(pretrained=False, **kwargs):
549588
ls_init_value=1e-6,
550589
head_fn='norm_mlp',
551590
)
552-
return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs))
591+
return _create_mambaout('mambaout_base_tall_rw', pretrained=pretrained, **dict(model_args, **kwargs))
553592

554593

555594
@register_model
556-
def mambaout_base_plus_rw(pretrained=False, **kwargs):
595+
def mambaout_base_wide_rw(pretrained=False, **kwargs):
557596
model_args = dict(
558597
depths=(3, 4, 27, 3),
559598
dims=(128, 256, 512, 768),
@@ -565,6 +604,22 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
565604
act_layer='silu',
566605
head_fn='norm_mlp',
567606
)
607+
return _create_mambaout('mambaout_base_wide_rw', pretrained=pretrained, **dict(model_args, **kwargs))
608+
609+
610+
@register_model
def mambaout_base_plus_rw(pretrained=False, **kwargs):
    """MambaOut-Base-Plus — a timm experimental variant: deeper third stage
    (depths (3, 4, 30, 3)), 3.0x expansion, 1.5x conv ratio, SiLU activation.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides forwarded to the model constructor.
    """
    defaults = {
        'depths': (3, 4, 30, 3),
        'dims': (128, 256, 512, 768),
        'expansion_ratio': 3.0,
        'conv_ratio': 1.5,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-6,
        'act_layer': 'silu',
        'head_fn': 'norm_mlp',
    }
    # kwargs take precedence over the variant defaults, as in
    # dict(model_args, **kwargs).
    return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **{**defaults, **kwargs})
569624

570625

0 commit comments

Comments
 (0)