Skip to content

Commit 8ec554b

Browse files
authored
Merge pull request #136 from yoniaflalo/adding_effnet_pruned
added efficientnet pruned weights
2 parents a4d20a1 + 9c15d57 commit 8ec554b

File tree

4 files changed

+53
-2
lines changed

4 files changed

+53
-2
lines changed

timm/models/efficientnet.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from .efficientnet_builder import *
2828
from .feature_hooks import FeatureHooks
2929
from .registry import register_model
30-
from .helpers import load_pretrained
30+
from .helpers import load_pretrained, adapt_model_from_file
3131
from .layers import SelectAdaptivePool2d
3232
from timm.models.layers import create_conv2d
3333
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@@ -131,6 +131,16 @@ def _cfg(url='', **kwargs):
131131
'efficientnet_lite4': _cfg(
132132
url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
133133

134+
'efficientnet_b1_pruned': _cfg(
135+
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
136+
input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
137+
'efficientnet_b2_pruned': _cfg(
138+
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
139+
input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
140+
'efficientnet_b3_pruned': _cfg(
141+
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
142+
input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
143+
134144
'tf_efficientnet_b0': _cfg(
135145
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
136146
input_size=(3, 224, 224)),
@@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False):
482492
else:
483493
load_strict = True
484494
model_class = EfficientNet
485-
495+
variant = model_kwargs.pop('variant', '')
486496
model = model_class(**model_kwargs)
487497
model.default_cfg = default_cfg
498+
if '_pruned' in variant:
499+
model = adapt_model_from_file(model, variant)
488500
if pretrained:
489501
load_pretrained(
490502
model,
@@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
730742
channel_multiplier=channel_multiplier,
731743
act_layer=Swish,
732744
norm_kwargs=resolve_bn_args(kwargs),
745+
variant=variant,
733746
**kwargs,
734747
)
735748
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
@@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs):
12291242
return model
12301243

12311244

1245+
1246+
1247+
@register_model
1248+
def efficientnet_b1_pruned(pretrained=False, **kwargs):
1249+
""" EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
1250+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1251+
kwargs['pad_type'] = 'same'
1252+
variant = 'efficientnet_b1_pruned'
1253+
model = _gen_efficientnet(
1254+
variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
1255+
return model
1256+
1257+
1258+
@register_model
1259+
def efficientnet_b2_pruned(pretrained=False, **kwargs):
1260+
""" EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
1261+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1262+
kwargs['pad_type'] = 'same'
1263+
model = _gen_efficientnet(
1264+
'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
1265+
return model
1266+
1267+
1268+
@register_model
1269+
def efficientnet_b3_pruned(pretrained=False, **kwargs):
1270+
""" EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
1271+
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
1272+
kwargs['pad_type'] = 'same'
1273+
model = _gen_efficientnet(
1274+
'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
1275+
return model
1276+
1277+
1278+
1279+
12321280
@register_model
12331281
def tf_efficientnet_b0(pretrained=False, **kwargs):
12341282
""" EfficientNet-B0. Tensorflow compatible variant """

0 commit comments

Comments
 (0)