Skip to content

Commit c241081

Browse files
authored
Merge pull request #1850 from huggingface/effnet_improve_features_only
Support other features only modes for EfficientNet. Fix #1848 fix #1849
2 parents f9a24fa + 47517db commit c241081

File tree

7 files changed

+51
-34
lines changed

7 files changed

+51
-34
lines changed

timm/models/_efficientnet_blocks.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def feature_info(self, location):
7575
if location == 'expansion': # output of conv after act, same as block output
7676
return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
7777
else: # location == 'bottleneck', block output
78-
return dict(module='', hook_type='', num_chs=self.conv.out_channels)
78+
return dict(module='', num_chs=self.conv.out_channels)
7979

8080
def forward(self, x):
8181
shortcut = x
@@ -116,7 +116,7 @@ def feature_info(self, location):
116116
if location == 'expansion': # after SE, input to PW
117117
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
118118
else: # location == 'bottleneck', block output
119-
return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
119+
return dict(module='', num_chs=self.conv_pw.out_channels)
120120

121121
def forward(self, x):
122122
shortcut = x
@@ -173,7 +173,7 @@ def feature_info(self, location):
173173
if location == 'expansion': # after SE, input to PWL
174174
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
175175
else: # location == 'bottleneck', block output
176-
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
176+
return dict(module='', num_chs=self.conv_pwl.out_channels)
177177

178178
def forward(self, x):
179179
shortcut = x
@@ -266,7 +266,7 @@ def feature_info(self, location):
266266
if location == 'expansion': # after SE, before PWL
267267
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
268268
else: # location == 'bottleneck', block output
269-
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
269+
return dict(module='', num_chs=self.conv_pwl.out_channels)
270270

271271
def forward(self, x):
272272
shortcut = x

timm/models/_efficientnet_builder.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,7 @@ def __call__(self, in_chs, model_block_args):
370370
stages = []
371371
if model_block_args[0][0]['stride'] > 1:
372372
# if the first block starts with a stride, we need to extract first level feat from stem
373-
feature_info = dict(
374-
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
375-
hook_type='forward' if self.feature_location != 'bottleneck' else '')
373+
feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
376374
self.features.append(feature_info)
377375

378376
# outer list of block_args defines the stacks
@@ -418,10 +416,16 @@ def __call__(self, in_chs, model_block_args):
418416
# stash feature module name and channel info for model feature extraction
419417
if extract_features:
420418
feature_info = dict(
421-
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
422-
module_name = f'blocks.{stack_idx}.{block_idx}'
419+
stage=stack_idx + 1,
420+
reduction=current_stride,
421+
**block.feature_info(self.feature_location),
422+
)
423423
leaf_name = feature_info.get('module', '')
424-
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
424+
if leaf_name:
425+
feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
426+
else:
427+
assert last_block
428+
feature_info['module'] = f'blocks.{stack_idx}'
425429
self.features.append(feature_info)
426430

427431
total_block_idx += 1 # incr global block idx (across all stacks)

timm/models/_features.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@ class FeatureInfo:
2727

2828
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
2929
prev_reduction = 1
30-
for fi in feature_info:
30+
for i, fi in enumerate(feature_info):
3131
# sanity check the mandatory fields, there may be additional fields depending on the model
3232
assert 'num_chs' in fi and fi['num_chs'] > 0
3333
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
3434
prev_reduction = fi['reduction']
3535
assert 'module' in fi
36+
fi.setdefault('index', i)
3637
self.out_indices = out_indices
3738
self.info = feature_info
3839

timm/models/_features_fx.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from torch import nn
88

9-
from ._features import _get_feature_info
9+
from ._features import _get_feature_info, _get_return_layers
1010

1111
try:
1212
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
@@ -93,9 +93,7 @@ def __init__(self, model, out_indices, out_map=None):
9393
self.feature_info = _get_feature_info(model, out_indices)
9494
if out_map is not None:
9595
assert len(out_map) == len(out_indices)
96-
return_nodes = {
97-
info['module']: out_map[i] if out_map is not None else info['module']
98-
for i, info in enumerate(self.feature_info) if i in out_indices}
96+
return_nodes = _get_return_layers(self.feature_info, out_map)
9997
self.graph_module = create_feature_extractor(model, return_nodes)
10098

10199
def forward(self, x):

timm/models/efficientnet.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def __init__(
232232
)
233233
self.blocks = nn.Sequential(*builder(stem_size, block_args))
234234
self.feature_info = FeatureInfo(builder.features, out_indices)
235-
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
235+
self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}
236236

237237
efficientnet_init_weights(self)
238238

@@ -268,20 +268,28 @@ def forward(self, x) -> List[torch.Tensor]:
268268

269269

270270
def _create_effnet(variant, pretrained=False, **kwargs):
271-
features_only = False
271+
features_mode = ''
272272
model_cls = EfficientNet
273273
kwargs_filter = None
274274
if kwargs.pop('features_only', False):
275-
features_only = True
276-
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
277-
model_cls = EfficientNetFeatures
275+
if 'feature_cfg' in kwargs:
276+
features_mode = 'cfg'
277+
else:
278+
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
279+
model_cls = EfficientNetFeatures
280+
features_mode = 'cls'
281+
278282
model = build_model_with_cfg(
279-
model_cls, variant, pretrained,
280-
pretrained_strict=not features_only,
283+
model_cls,
284+
variant,
285+
pretrained,
286+
features_only=features_mode == 'cfg',
287+
pretrained_strict=features_mode != 'cls',
281288
kwargs_filter=kwargs_filter,
282-
**kwargs)
283-
if features_only:
284-
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
289+
**kwargs,
290+
)
291+
if features_mode == 'cls':
292+
model.pretrained_cfg = model.default_cfg = pretrained_cfg_for_features(model.pretrained_cfg)
285293
return model
286294

287295

timm/models/hrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def __init__(
829829
**kwargs,
830830
)
831831
self.feature_info = FeatureInfo(self.feature_info, out_indices)
832-
self._out_idx = {i for i in out_indices}
832+
self._out_idx = {f['index'] for f in self.feature_info.get_dicts()}
833833

834834
def forward_features(self, x):
835835
assert False, 'Not supported'

timm/models/mobilenetv3.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def __init__(
210210
)
211211
self.blocks = nn.Sequential(*builder(stem_size, block_args))
212212
self.feature_info = FeatureInfo(builder.features, out_indices)
213-
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
213+
self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}
214214

215215
efficientnet_init_weights(self)
216216

@@ -247,21 +247,27 @@ def forward(self, x) -> List[torch.Tensor]:
247247

248248

249249
def _create_mnv3(variant, pretrained=False, **kwargs):
250-
features_only = False
250+
features_mode = ''
251251
model_cls = MobileNetV3
252252
kwargs_filter = None
253253
if kwargs.pop('features_only', False):
254-
features_only = True
255-
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
256-
model_cls = MobileNetV3Features
254+
if 'feature_cfg' in kwargs:
255+
features_mode = 'cfg'
256+
else:
257+
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
258+
model_cls = MobileNetV3Features
259+
features_mode = 'cls'
260+
257261
model = build_model_with_cfg(
258262
model_cls,
259263
variant,
260264
pretrained,
261-
pretrained_strict=not features_only,
265+
features_only=features_mode == 'cfg',
266+
pretrained_strict=features_mode != 'cls',
262267
kwargs_filter=kwargs_filter,
263-
**kwargs)
264-
if features_only:
268+
**kwargs,
269+
)
270+
if features_mode == 'cls':
265271
model.default_cfg = pretrained_cfg_for_features(model.default_cfg)
266272
return model
267273

0 commit comments

Comments
 (0)