Skip to content

Commit c0d7388 — "Improving kwarg merging in more models" (1 parent: 94a9159)

File tree

5 files changed: +292 additions, −236 deletions

timm/models/dpn.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch.nn.functional as F
1616

1717
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18-
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
18+
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
1919
from ._builder import build_model_with_cfg
2020
from ._registry import register_model
2121

@@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
3333

3434

3535
default_cfgs = {
36+
'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
3637
'dpn68': _cfg(
3738
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
3839
'dpn68b': _cfg(
@@ -82,7 +83,16 @@ def forward(self, x):
8283

8384
class DualPathBlock(nn.Module):
8485
def __init__(
85-
self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
86+
self,
87+
in_chs,
88+
num_1x1_a,
89+
num_3x3_b,
90+
num_1x1_c,
91+
inc,
92+
groups,
93+
block_type='normal',
94+
b=False,
95+
):
8696
super(DualPathBlock, self).__init__()
8797
self.num_1x1_c = num_1x1_c
8898
self.inc = inc
@@ -167,16 +177,31 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
167177

168178
class DPN(nn.Module):
169179
def __init__(
170-
self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg',
171-
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
172-
num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU):
180+
self,
181+
num_classes=1000,
182+
in_chans=3,
183+
output_stride=32,
184+
global_pool='avg',
185+
k_sec=(3, 4, 20, 3),
186+
inc_sec=(16, 32, 24, 128),
187+
k_r=96,
188+
groups=32,
189+
small=False,
190+
num_init_features=64,
191+
b=False,
192+
drop_rate=0.,
193+
norm_layer='batchnorm2d',
194+
act_layer='relu',
195+
fc_act_layer=nn.ELU,
196+
):
173197
super(DPN, self).__init__()
174198
self.num_classes = num_classes
175199
self.drop_rate = drop_rate
176200
self.b = b
177201
assert output_stride == 32 # FIXME look into dilation support
178-
norm_layer = partial(BatchNormAct2d, eps=.001)
179-
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False)
202+
203+
norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
204+
fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
180205
bw_factor = 1 if small else 4
181206
blocks = OrderedDict()
182207

@@ -291,49 +316,57 @@ def _create_dpn(variant, pretrained=False, **kwargs):
291316
**kwargs)
292317

293318

319+
@register_model
320+
def dpn48b(pretrained=False, **kwargs):
321+
model_kwargs = dict(
322+
small=True, num_init_features=10, k_r=128, groups=32,
323+
b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
324+
return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
325+
326+
294327
@register_model
295328
def dpn68(pretrained=False, **kwargs):
296329
model_kwargs = dict(
297330
small=True, num_init_features=10, k_r=128, groups=32,
298-
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
299-
return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
331+
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
332+
return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs))
300333

301334

302335
@register_model
303336
def dpn68b(pretrained=False, **kwargs):
304337
model_kwargs = dict(
305338
small=True, num_init_features=10, k_r=128, groups=32,
306-
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
307-
return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
339+
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
340+
return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
308341

309342

310343
@register_model
311344
def dpn92(pretrained=False, **kwargs):
312345
model_kwargs = dict(
313346
num_init_features=64, k_r=96, groups=32,
314-
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
315-
return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
347+
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
348+
return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs))
316349

317350

318351
@register_model
319352
def dpn98(pretrained=False, **kwargs):
320353
model_kwargs = dict(
321354
num_init_features=96, k_r=160, groups=40,
322-
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
323-
return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
355+
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
356+
return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs))
324357

325358

326359
@register_model
327360
def dpn131(pretrained=False, **kwargs):
328361
model_kwargs = dict(
329362
num_init_features=128, k_r=160, groups=40,
330-
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
331-
return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
363+
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
364+
return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs))
332365

333366

334367
@register_model
335368
def dpn107(pretrained=False, **kwargs):
336369
model_kwargs = dict(
337370
num_init_features=128, k_r=200, groups=50,
338-
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
339-
return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
371+
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
372+
return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs))

timm/models/maxxvit.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1116,6 +1116,26 @@ def forward(self, x, pre_logits: bool = False):
11161116
return x
11171117

11181118

1119+
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
1120+
transformer_kwargs = {}
1121+
conv_kwargs = {}
1122+
base_kwargs = {}
1123+
for k, v in kwargs.items():
1124+
if k.startswith('transformer_'):
1125+
transformer_kwargs[k.replace('transformer_', '')] = v
1126+
elif k.startswith('conv_'):
1127+
conv_kwargs[k.replace('conv_', '')] = v
1128+
else:
1129+
base_kwargs[k] = v
1130+
cfg = replace(
1131+
cfg,
1132+
transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
1133+
conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
1134+
**base_kwargs
1135+
)
1136+
return cfg
1137+
1138+
11191139
class MaxxVit(nn.Module):
11201140
""" CoaTNet + MaxVit base model.
11211141
@@ -1130,10 +1150,13 @@ def __init__(
11301150
num_classes: int = 1000,
11311151
global_pool: str = 'avg',
11321152
drop_rate: float = 0.,
1133-
drop_path_rate: float = 0.
1153+
drop_path_rate: float = 0.,
1154+
**kwargs,
11341155
):
11351156
super().__init__()
11361157
img_size = to_2tuple(img_size)
1158+
if kwargs:
1159+
cfg = _overlay_kwargs(cfg, **kwargs)
11371160
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
11381161
self.num_classes = num_classes
11391162
self.global_pool = global_pool

timm/models/res2net.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def res2net50_26w_4s(pretrained=False, **kwargs):
156156
pretrained (bool): If True, returns a model pre-trained on ImageNet
157157
"""
158158
model_args = dict(
159-
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
160-
return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
159+
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
160+
return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs))
161161

162162

163163
@register_model
@@ -167,8 +167,8 @@ def res2net101_26w_4s(pretrained=False, **kwargs):
167167
pretrained (bool): If True, returns a model pre-trained on ImageNet
168168
"""
169169
model_args = dict(
170-
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
171-
return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
170+
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
171+
return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs))
172172

173173

174174
@register_model
@@ -178,8 +178,8 @@ def res2net50_26w_6s(pretrained=False, **kwargs):
178178
pretrained (bool): If True, returns a model pre-trained on ImageNet
179179
"""
180180
model_args = dict(
181-
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
182-
return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
181+
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
182+
return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs))
183183

184184

185185
@register_model
@@ -189,8 +189,8 @@ def res2net50_26w_8s(pretrained=False, **kwargs):
189189
pretrained (bool): If True, returns a model pre-trained on ImageNet
190190
"""
191191
model_args = dict(
192-
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
193-
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
192+
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
193+
return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs))
194194

195195

196196
@register_model
@@ -200,8 +200,8 @@ def res2net50_48w_2s(pretrained=False, **kwargs):
200200
pretrained (bool): If True, returns a model pre-trained on ImageNet
201201
"""
202202
model_args = dict(
203-
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
204-
return _create_res2net('res2net50_48w_2s', pretrained, **model_args)
203+
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
204+
return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs))
205205

206206

207207
@register_model
@@ -211,8 +211,8 @@ def res2net50_14w_8s(pretrained=False, **kwargs):
211211
pretrained (bool): If True, returns a model pre-trained on ImageNet
212212
"""
213213
model_args = dict(
214-
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
215-
return _create_res2net('res2net50_14w_8s', pretrained, **model_args)
214+
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
215+
return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs))
216216

217217

218218
@register_model
@@ -222,5 +222,5 @@ def res2next50(pretrained=False, **kwargs):
222222
pretrained (bool): If True, returns a model pre-trained on ImageNet
223223
"""
224224
model_args = dict(
225-
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
226-
return _create_res2net('res2next50', pretrained, **model_args)
225+
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
226+
return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))

timm/models/resnest.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,8 @@ def resnest14d(pretrained=False, **kwargs):
163163
model_kwargs = dict(
164164
block=ResNestBottleneck, layers=[1, 1, 1, 1],
165165
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
166-
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
167-
return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
166+
block_args=dict(radix=2, avd=True, avd_first=False))
167+
return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
168168

169169

170170
@register_model
@@ -174,8 +174,8 @@ def resnest26d(pretrained=False, **kwargs):
174174
model_kwargs = dict(
175175
block=ResNestBottleneck, layers=[2, 2, 2, 2],
176176
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
177-
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
178-
return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
177+
block_args=dict(radix=2, avd=True, avd_first=False))
178+
return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
179179

180180

181181
@register_model
@@ -186,8 +186,8 @@ def resnest50d(pretrained=False, **kwargs):
186186
model_kwargs = dict(
187187
block=ResNestBottleneck, layers=[3, 4, 6, 3],
188188
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
189-
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
190-
return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
189+
block_args=dict(radix=2, avd=True, avd_first=False))
190+
return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
191191

192192

193193
@register_model
@@ -198,8 +198,8 @@ def resnest101e(pretrained=False, **kwargs):
198198
model_kwargs = dict(
199199
block=ResNestBottleneck, layers=[3, 4, 23, 3],
200200
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
201-
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
202-
return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
201+
block_args=dict(radix=2, avd=True, avd_first=False))
202+
return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
203203

204204

205205
@register_model
@@ -210,8 +210,8 @@ def resnest200e(pretrained=False, **kwargs):
210210
model_kwargs = dict(
211211
block=ResNestBottleneck, layers=[3, 24, 36, 3],
212212
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
213-
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
214-
return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
213+
block_args=dict(radix=2, avd=True, avd_first=False))
214+
return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
215215

216216

217217
@register_model
@@ -222,8 +222,8 @@ def resnest269e(pretrained=False, **kwargs):
222222
model_kwargs = dict(
223223
block=ResNestBottleneck, layers=[3, 30, 48, 8],
224224
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
225-
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
226-
return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
225+
block_args=dict(radix=2, avd=True, avd_first=False))
226+
return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
227227

228228

229229
@register_model
@@ -233,8 +233,8 @@ def resnest50d_4s2x40d(pretrained=False, **kwargs):
233233
model_kwargs = dict(
234234
block=ResNestBottleneck, layers=[3, 4, 6, 3],
235235
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
236-
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
237-
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
236+
block_args=dict(radix=4, avd=True, avd_first=True))
237+
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
238238

239239

240240
@register_model
@@ -244,5 +244,5 @@ def resnest50d_1s4x24d(pretrained=False, **kwargs):
244244
model_kwargs = dict(
245245
block=ResNestBottleneck, layers=[3, 4, 6, 3],
246246
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
247-
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
248-
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
247+
block_args=dict(radix=1, avd=True, avd_first=True))
248+
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

Comments: 0