Skip to content

Commit 5078b28

Browse files
committed
More kwarg handling tweaks, maxvit_base_rw def added
1 parent c0d7388 commit 5078b28

File tree

6 files changed

+117
-56
lines changed

6 files changed

+117
-56
lines changed

timm/models/densenet.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from torch.jit.annotations import List
1313

1414
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15-
from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
15+
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
1616
from ._builder import build_model_with_cfg
1717
from ._manipulate import MATCH_PREV_GROUP
1818
from ._registry import register_model
@@ -115,8 +115,15 @@ class DenseBlock(nn.ModuleDict):
115115
_version = 2
116116

117117
def __init__(
118-
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
119-
drop_rate=0., memory_efficient=False):
118+
self,
119+
num_layers,
120+
num_input_features,
121+
bn_size,
122+
growth_rate,
123+
norm_layer=BatchNormAct2d,
124+
drop_rate=0.,
125+
memory_efficient=False,
126+
):
120127
super(DenseBlock, self).__init__()
121128
for i in range(num_layers):
122129
layer = DenseLayer(
@@ -165,12 +172,25 @@ class DenseNet(nn.Module):
165172
"""
166173

167174
def __init__(
168-
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
169-
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
170-
memory_efficient=False, aa_stem_only=True):
175+
self,
176+
growth_rate=32,
177+
block_config=(6, 12, 24, 16),
178+
num_classes=1000,
179+
in_chans=3,
180+
global_pool='avg',
181+
bn_size=4,
182+
stem_type='',
183+
act_layer='relu',
184+
norm_layer='batchnorm2d',
185+
aa_layer=None,
186+
drop_rate=0,
187+
memory_efficient=False,
188+
aa_stem_only=True,
189+
):
171190
self.num_classes = num_classes
172191
self.drop_rate = drop_rate
173192
super(DenseNet, self).__init__()
193+
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
174194

175195
# Stem
176196
deep_stem = 'deep' in stem_type # 3x3 deep stem
@@ -226,8 +246,11 @@ def __init__(
226246
dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
227247
current_stride *= 2
228248
trans = DenseTransition(
229-
num_input_features=num_features, num_output_features=num_features // 2,
230-
norm_layer=norm_layer, aa_layer=transition_aa_layer)
249+
num_input_features=num_features,
250+
num_output_features=num_features // 2,
251+
norm_layer=norm_layer,
252+
aa_layer=transition_aa_layer,
253+
)
231254
self.features.add_module(f'transition{i + 1}', trans)
232255
num_features = num_features // 2
233256

@@ -322,8 +345,8 @@ def densenetblur121d(pretrained=False, **kwargs):
322345
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
323346
"""
324347
model = _create_densenet(
325-
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
326-
aa_layer=BlurPool2d, **kwargs)
348+
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained,
349+
stem_type='deep', aa_layer=BlurPool2d, **kwargs)
327350
return model
328351

329352

@@ -382,11 +405,9 @@ def densenet264(pretrained=False, **kwargs):
382405
def densenet264d_iabn(pretrained=False, **kwargs):
383406
r"""Densenet-264 model with deep stem and Inplace-ABN
384407
"""
385-
def norm_act_fn(num_features, **kwargs):
386-
return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
387408
model = _create_densenet(
388409
'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
389-
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
410+
norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
390411
return model
391412

392413

timm/models/dpn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,21 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
178178
class DPN(nn.Module):
179179
def __init__(
180180
self,
181-
num_classes=1000,
182-
in_chans=3,
183-
output_stride=32,
184-
global_pool='avg',
185181
k_sec=(3, 4, 20, 3),
186182
inc_sec=(16, 32, 24, 128),
187183
k_r=96,
188184
groups=32,
185+
num_classes=1000,
186+
in_chans=3,
187+
output_stride=32,
188+
global_pool='avg',
189189
small=False,
190190
num_init_features=64,
191191
b=False,
192192
drop_rate=0.,
193193
norm_layer='batchnorm2d',
194194
act_layer='relu',
195-
fc_act_layer=nn.ELU,
195+
fc_act_layer='elu',
196196
):
197197
super(DPN, self).__init__()
198198
self.num_classes = num_classes

timm/models/maxxvit.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,6 +1680,26 @@ def _tf_cfg():
16801680
init_values=1e-6,
16811681
),
16821682
),
1683+
maxvit_rmlp_base_rw_224=MaxxVitCfg(
1684+
embed_dim=(96, 192, 384, 768),
1685+
depths=(2, 6, 14, 2),
1686+
block_type=('M',) * 4,
1687+
stem_width=(32, 64),
1688+
head_hidden_size=768,
1689+
**_rw_max_cfg(
1690+
rel_pos_type='mlp',
1691+
),
1692+
),
1693+
maxvit_rmlp_base_rw_384=MaxxVitCfg(
1694+
embed_dim=(96, 192, 384, 768),
1695+
depths=(2, 6, 14, 2),
1696+
block_type=('M',) * 4,
1697+
stem_width=(32, 64),
1698+
head_hidden_size=768,
1699+
**_rw_max_cfg(
1700+
rel_pos_type='mlp',
1701+
),
1702+
),
16831703

16841704
maxvit_tiny_pm_256=MaxxVitCfg(
16851705
embed_dim=(64, 128, 256, 512),
@@ -1862,6 +1882,12 @@ def _cfg(url='', **kwargs):
18621882
'maxvit_rmlp_small_rw_256': _cfg(
18631883
url='',
18641884
input_size=(3, 256, 256), pool_size=(8, 8)),
1885+
'maxvit_rmlp_base_rw_224': _cfg(
1886+
url='',
1887+
),
1888+
'maxvit_rmlp_base_rw_384': _cfg(
1889+
url='',
1890+
input_size=(3, 384, 384), pool_size=(12, 12)),
18651891

18661892
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
18671893

@@ -2091,6 +2117,16 @@ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
20912117
return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
20922118

20932119

2120+
@register_model
2121+
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
2122+
return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
2123+
2124+
2125+
@register_model
2126+
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs):
2127+
return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
2128+
2129+
20942130
@register_model
20952131
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
20962132
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)

timm/models/mobilevit.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,16 @@ def __init__(
266266

267267
self.transformer = nn.Sequential(*[
268268
TransformerBlock(
269-
transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
270-
attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
271-
act_layer=layers.act, norm_layer=transformer_norm_layer)
269+
transformer_dim,
270+
mlp_ratio=mlp_ratio,
271+
num_heads=num_heads,
272+
qkv_bias=True,
273+
attn_drop=attn_drop,
274+
drop=drop,
275+
drop_path=drop_path_rate,
276+
act_layer=layers.act,
277+
norm_layer=transformer_norm_layer,
278+
)
272279
for _ in range(transformer_depth)
273280
])
274281
self.norm = transformer_norm_layer(transformer_dim)

timm/models/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,7 +1298,7 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs):
12981298
model_args = dict(
12991299
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
13001300
block_args=dict(attn_layer='eca'))
1301-
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
1301+
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
13021302

13031303

13041304
@register_model
@@ -1340,7 +1340,7 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
13401340
model_args = dict(
13411341
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
13421342
block_args=dict(attn_layer='eca'))
1343-
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
1343+
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
13441344

13451345

13461346
@register_model

timm/models/resnetv2.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -746,86 +746,83 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
746746

747747
@register_model
748748
def resnetv2_50(pretrained=False, **kwargs):
749-
return _create_resnetv2(
750-
'resnetv2_50', pretrained=pretrained,
751-
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
749+
model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
750+
return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs))
752751

753752

754753
@register_model
755754
def resnetv2_50d(pretrained=False, **kwargs):
756-
return _create_resnetv2(
757-
'resnetv2_50d', pretrained=pretrained,
755+
model_args = dict(
758756
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
759-
stem_type='deep', avg_down=True, **kwargs)
757+
stem_type='deep', avg_down=True)
758+
return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs))
760759

761760

762761
@register_model
763762
def resnetv2_50t(pretrained=False, **kwargs):
764-
return _create_resnetv2(
765-
'resnetv2_50t', pretrained=pretrained,
763+
model_args = dict(
766764
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
767-
stem_type='tiered', avg_down=True, **kwargs)
765+
stem_type='tiered', avg_down=True)
766+
return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs))
768767

769768

770769
@register_model
771770
def resnetv2_101(pretrained=False, **kwargs):
772-
return _create_resnetv2(
773-
'resnetv2_101', pretrained=pretrained,
774-
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
771+
model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
772+
return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs))
775773

776774

777775
@register_model
778776
def resnetv2_101d(pretrained=False, **kwargs):
779-
return _create_resnetv2(
780-
'resnetv2_101d', pretrained=pretrained,
777+
model_args = dict(
781778
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
782-
stem_type='deep', avg_down=True, **kwargs)
779+
stem_type='deep', avg_down=True)
780+
return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs))
783781

784782

785783
@register_model
786784
def resnetv2_152(pretrained=False, **kwargs):
787-
return _create_resnetv2(
788-
'resnetv2_152', pretrained=pretrained,
789-
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
785+
model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
786+
return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs))
790787

791788

792789
@register_model
793790
def resnetv2_152d(pretrained=False, **kwargs):
794-
return _create_resnetv2(
795-
'resnetv2_152d', pretrained=pretrained,
791+
model_args = dict(
796792
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
797-
stem_type='deep', avg_down=True, **kwargs)
793+
stem_type='deep', avg_down=True)
794+
return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs))
798795

799796

800797
# Experimental configs (may change / be removed)
801798

802799
@register_model
803800
def resnetv2_50d_gn(pretrained=False, **kwargs):
804-
return _create_resnetv2(
805-
'resnetv2_50d_gn', pretrained=pretrained,
801+
model_args = dict(
806802
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
807-
stem_type='deep', avg_down=True, **kwargs)
803+
stem_type='deep', avg_down=True)
804+
return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
808805

809806

810807
@register_model
811808
def resnetv2_50d_evob(pretrained=False, **kwargs):
812-
return _create_resnetv2(
813-
'resnetv2_50d_evob', pretrained=pretrained,
809+
model_args = dict(
814810
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
815-
stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
811+
stem_type='deep', avg_down=True, zero_init_last=True)
812+
return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
816813

817814

818815
@register_model
819816
def resnetv2_50d_evos(pretrained=False, **kwargs):
820-
return _create_resnetv2(
821-
'resnetv2_50d_evos', pretrained=pretrained,
817+
model_args = dict(
822818
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
823-
stem_type='deep', avg_down=True, **kwargs)
819+
stem_type='deep', avg_down=True)
820+
return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs))
824821

825822

826823
@register_model
827824
def resnetv2_50d_frn(pretrained=False, **kwargs):
828-
return _create_resnetv2(
829-
'resnetv2_50d_frn', pretrained=pretrained,
825+
model_args = dict(
830826
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
831-
stem_type='deep', avg_down=True, **kwargs)
827+
stem_type='deep', avg_down=True)
828+
return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)