|
27 | 27 | from .efficientnet_builder import * |
28 | 28 | from .feature_hooks import FeatureHooks |
29 | 29 | from .registry import register_model |
30 | | -from .helpers import load_pretrained |
| 30 | +from .helpers import load_pretrained, adapt_model_from_file |
31 | 31 | from .layers import SelectAdaptivePool2d |
32 | 32 | from timm.models.layers import create_conv2d |
33 | 33 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD |
@@ -131,6 +131,16 @@ def _cfg(url='', **kwargs): |
131 | 131 | 'efficientnet_lite4': _cfg( |
132 | 132 | url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), |
133 | 133 |
|
| 134 | + 'efficientnet_b1_pruned': _cfg( |
| 135 | + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth', |
| 136 | + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), |
| 137 | + 'efficientnet_b2_pruned': _cfg( |
| 138 | + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth', |
| 139 | + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), |
| 140 | + 'efficientnet_b3_pruned': _cfg( |
| 141 | + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth', |
| 142 | + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), |
| 143 | + |
134 | 144 | 'tf_efficientnet_b0': _cfg( |
135 | 145 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', |
136 | 146 | input_size=(3, 224, 224)), |
@@ -482,9 +492,11 @@ def _create_model(model_kwargs, default_cfg, pretrained=False): |
482 | 492 | else: |
483 | 493 | load_strict = True |
484 | 494 | model_class = EfficientNet |
485 | | - |
| 495 | + variant = model_kwargs.pop('variant', '') |
486 | 496 | model = model_class(**model_kwargs) |
487 | 497 | model.default_cfg = default_cfg |
| 498 | + if '_pruned' in variant: |
| 499 | + model = adapt_model_from_file(model, variant) |
488 | 500 | if pretrained: |
489 | 501 | load_pretrained( |
490 | 502 | model, |
@@ -730,6 +742,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre |
730 | 742 | channel_multiplier=channel_multiplier, |
731 | 743 | act_layer=Swish, |
732 | 744 | norm_kwargs=resolve_bn_args(kwargs), |
| 745 | + variant=variant, |
733 | 746 | **kwargs, |
734 | 747 | ) |
735 | 748 | model = _create_model(model_kwargs, default_cfgs[variant], pretrained) |
@@ -1229,6 +1242,41 @@ def efficientnet_lite4(pretrained=False, **kwargs): |
1229 | 1242 | return model |
1230 | 1243 |
|
1231 | 1244 |
|
| 1245 | + |
| 1246 | + |
| 1247 | +@register_model |
| 1248 | +def efficientnet_b1_pruned(pretrained=False, **kwargs): |
| 1249 | + """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ |
| 1250 | + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT |
| 1251 | + kwargs['pad_type'] = 'same' |
| 1252 | + variant = 'efficientnet_b1_pruned' |
| 1253 | + model = _gen_efficientnet( |
| 1254 | + variant, channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) |
| 1255 | + return model |
| 1256 | + |
| 1257 | + |
| 1258 | +@register_model |
| 1259 | +def efficientnet_b2_pruned(pretrained=False, **kwargs): |
| 1260 | + """ EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ |
| 1261 | + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT |
| 1262 | + kwargs['pad_type'] = 'same' |
| 1263 | + model = _gen_efficientnet( |
| 1264 | + 'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) |
| 1265 | + return model |
| 1266 | + |
| 1267 | + |
| 1268 | +@register_model |
| 1269 | +def efficientnet_b3_pruned(pretrained=False, **kwargs): |
| 1270 | + """ EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ |
| 1271 | + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT |
| 1272 | + kwargs['pad_type'] = 'same' |
| 1273 | + model = _gen_efficientnet( |
| 1274 | + 'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) |
| 1275 | + return model |
| 1276 | + |
| 1277 | + |
| 1278 | + |
| 1279 | + |
1232 | 1280 | @register_model |
1233 | 1281 | def tf_efficientnet_b0(pretrained=False, **kwargs): |
1234 | 1282 | """ EfficientNet-B0. Tensorflow compatible variant """ |
|
0 commit comments