@@ -2,7 +2,7 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Optional, Dict, Callable, Any, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
@@ -359,15 +359,15 @@ def build_model_with_cfg(
     * pruning config / model adaptation

     Args:
-        model_cls (nn.Module): model class
-        variant (str): model variant name
-        pretrained (bool): load pretrained weights
-        pretrained_cfg (dict): model's pretrained weight/task config
-        model_cfg (Optional[Dict]): model's architecture config
-        feature_cfg (Optional[Dict]: feature extraction adapter config
-        pretrained_strict (bool): load pretrained weights strictly
-        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
-        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+        model_cls: model class
+        variant: model variant name
+        pretrained: load pretrained weights
+        pretrained_cfg: model's pretrained weight/task config
+        model_cfg: model's architecture config
+        feature_cfg: feature extraction adapter config
+        pretrained_strict: load pretrained weights strictly
+        pretrained_filter_fn: filter callable for pretrained weights
+        kwargs_filter: kwargs to filter before passing to model
         **kwargs: model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
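For context, `build_model_with_cfg` is normally invoked from a per-model entrypoint rather than called directly. A minimal sketch of that pattern, assuming timm's usual ResNet entrypoint shape (the exact model args here are illustrative, not part of this diff):

```python
from timm.models import build_model_with_cfg
from timm.models.resnet import ResNet, Bottleneck

# Illustrative entrypoint: user kwargs are forwarded straight into
# build_model_with_cfg, which is what makes the kwarg handling below work.
def resnet50(pretrained: bool = False, **kwargs):
    model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3))
    return build_model_with_cfg(
        ResNet, 'resnet50',
        pretrained=pretrained,
        **dict(model_args, **kwargs),
    )
```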
@@ -392,6 +392,8 @@ def build_model_with_cfg(
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')
+        if 'feature_cls' in kwargs:
+            feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

     # Instantiate the model
     if model_cfg is None:
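With this hunk, `feature_cls` joins `out_indices` as a kwarg that gets lifted into `feature_cfg`, so it can be chosen at `create_model()` time instead of being hard-coded by the entrypoint. A hedged usage sketch, assuming the entrypoint forwards unknown kwargs as in the sketch above:

```python
import timm

# 'feature_cls' travels: create_model kwargs -> entrypoint -> build_model_with_cfg,
# where the pop added above moves it into feature_cfg.
model = timm.create_model(
    'resnet50',
    features_only=True,
    out_indices=(1, 2, 3, 4),
    feature_cls='hook',  # select FeatureHookNet instead of the FeatureListNet default
)
```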
@@ -418,24 +420,36 @@ def build_model_with_cfg(

     # Wrap the model in a feature extraction module if enabled
     if features:
-        feature_cls = FeatureListNet
-        output_fmt = getattr(model, 'output_fmt', None)
-        if output_fmt is not None:
-            feature_cfg.setdefault('output_fmt', output_fmt)
+        use_getter = False
         if 'feature_cls' in feature_cfg:
             feature_cls = feature_cfg.pop('feature_cls')
             if isinstance(feature_cls, str):
                 feature_cls = feature_cls.lower()
+
+                # flatten_sequential only valid for some feature extractors
+                if feature_cls not in ('dict', 'list', 'hook'):
+                    feature_cfg.pop('flatten_sequential', None)
+
                 if 'hook' in feature_cls:
                     feature_cls = FeatureHookNet
+                elif feature_cls == 'list':
+                    feature_cls = FeatureListNet
                 elif feature_cls == 'dict':
                     feature_cls = FeatureDictNet
                 elif feature_cls == 'fx':
                     feature_cls = FeatureGraphNet
                 elif feature_cls == 'getter':
+                    use_getter = True
                     feature_cls = FeatureGetterNet
                 else:
                     assert False, f'Unknown feature class {feature_cls}'
+        else:
+            feature_cls = FeatureListNet
+
+        output_fmt = getattr(model, 'output_fmt', None)
+        if output_fmt is not None and not use_getter:  # don't set default for intermediate feat getter
+            feature_cfg.setdefault('output_fmt', output_fmt)
+
         model = feature_cls(model, **feature_cfg)
         model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained cfg
         model.default_cfg = model.pretrained_cfg  # alias for rename backwards compat (default_cfg -> pretrained_cfg)
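The rewritten dispatch makes the `FeatureListNet` default explicit, adds a `'list'` alias, drops `flatten_sequential` for extractors that don't accept it, and skips the `output_fmt` default for `FeatureGetterNet`, which manages output format on its own via the wrapped model's `forward_intermediates()`. A sketch of selecting the getter path, assuming a model that implements `forward_intermediates()` (recent ViT variants in timm do; the model name and indices are illustrative):

```python
import torch
import timm

model = timm.create_model(
    'vit_base_patch16_224',
    features_only=True,
    feature_cls='getter',   # routed to FeatureGetterNet by the dispatch above
    out_indices=(-2, -1),   # take the last two blocks
)
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # e.g. two NCHW feature maps
```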