Skip to content

Commit c2ba229

Browse files
committed
Prep for effcientnetv2_rw_m model weights that started training before official release..
1 parent 22f7c67 commit c2ba229

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

timm/models/efficientnet.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ def _cfg(url='', **kwargs):
162162
'efficientnetv2_rw_s': _cfg(
163163
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
164164
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
165+
'efficientnetv2_rw_m': _cfg(
166+
url='',
167+
input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
165168

166169
'efficientnetv2_s': _cfg(
167170
url='',
@@ -173,7 +176,6 @@ def _cfg(url='', **kwargs):
173176
url='',
174177
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
175178

176-
177179
'tf_efficientnet_b0': _cfg(
178180
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
179181
input_size=(3, 224, 224)),
@@ -1461,14 +1463,24 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
14611463

14621464
@register_model
14631465
def efficientnetv2_rw_s(pretrained=False, **kwargs):
1464-
""" EfficientNet-V2 Small.
1466+
""" EfficientNet-V2 Small RW variant.
14651467
NOTE: This is my initial (pre official code release) w/ some differences.
14661468
See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding
14671469
"""
14681470
model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
14691471
return model
14701472

14711473

1474+
@register_model
1475+
def efficientnetv2_rw_m(pretrained=False, **kwargs):
1476+
""" EfficientNet-V2 Medium RW variant.
1477+
"""
1478+
model = _gen_efficientnetv2_s(
1479+
'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True,
1480+
pretrained=pretrained, **kwargs)
1481+
return model
1482+
1483+
14721484
@register_model
14731485
def efficientnetv2_s(pretrained=False, **kwargs):
14741486
""" EfficientNet-V2 Small. """

timm/models/efficientnet_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,11 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
237237

238238
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
239239
arch_args = []
240-
for stack_idx, block_strings in enumerate(arch_def):
240+
if isinstance(depth_multiplier, tuple):
241+
assert len(depth_multiplier) == len(arch_def)
242+
else:
243+
depth_multiplier = (depth_multiplier,) * len(arch_def)
244+
for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
241245
assert isinstance(block_strings, list)
242246
stack_args = []
243247
repeats = []
@@ -251,7 +255,7 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
251255
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
252256
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
253257
else:
254-
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
258+
arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
255259
return arch_args
256260

257261

0 commit comments

Comments
 (0)