@@ -326,7 +326,6 @@ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3,
326326 # Stem
327327 if not fix_stem :
328328 stem_size = round_channels (stem_size , channel_multiplier , channel_divisor , channel_min )
329- print (stem_size )
330329 self .conv_stem = create_conv2d (self ._in_chs , stem_size , 3 , stride = 2 , padding = pad_type )
331330 self .bn1 = norm_layer (stem_size , ** norm_kwargs )
332331 self .act1 = act_layer (inplace = True )
@@ -393,7 +392,7 @@ class EfficientNetFeatures(nn.Module):
393392 and object detection models.
394393 """
395394
396- def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'pre_pwl ' ,
395+ def __init__ (self , block_args , out_indices = (0 , 1 , 2 , 3 , 4 ), feature_location = 'bottleneck ' ,
397396 in_chans = 3 , stem_size = 32 , channel_multiplier = 1.0 , channel_divisor = 8 , channel_min = None ,
398397 output_stride = 32 , pad_type = '' , fix_stem = False , act_layer = nn .ReLU , drop_rate = 0. , drop_path_rate = 0. ,
399398 se_kwargs = None , norm_layer = nn .BatchNorm2d , norm_kwargs = None ):
@@ -404,6 +403,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
404403 num_stages = max (out_indices ) + 1
405404
406405 self .out_indices = out_indices
406+ self .feature_location = feature_location
407407 self .drop_rate = drop_rate
408408 self ._in_chs = in_chans
409409
@@ -420,34 +420,56 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pr
420420 channel_multiplier , channel_divisor , channel_min , output_stride , pad_type , act_layer , se_kwargs ,
421421 norm_layer , norm_kwargs , drop_path_rate , feature_location = feature_location , verbose = _DEBUG )
422422 self .blocks = nn .Sequential (* builder (self ._in_chs , block_args ))
423- self .feature_info = builder .features # builder provides info about feature channels for each block
423+ self ._feature_info = builder .features # builder provides info about feature channels for each block
424+ self ._stage_to_feature_idx = {
425+ v ['stage_idx' ]: fi for fi , v in self ._feature_info .items () if fi in self .out_indices }
424426 self ._in_chs = builder .in_chs
425427
426428 efficientnet_init_weights (self )
427429 if _DEBUG :
428- for k , v in self .feature_info .items ():
430+ for k , v in self ._feature_info .items ():
429431 print ('Feature idx: {}: Name: {}, Channels: {}' .format (k , v ['name' ], v ['num_chs' ]))
430432
431433 # Register feature extraction hooks with FeatureHooks helper
432- hook_type = 'forward_pre' if feature_location == 'pre_pwl' else 'forward'
433- hooks = [dict (name = self .feature_info [idx ]['name' ], type = hook_type ) for idx in out_indices ]
434- self .feature_hooks = FeatureHooks (hooks , self .named_modules ())
434+ self .feature_hooks = None
435+ if feature_location != 'bottleneck' :
436+ hooks = [dict (
437+ name = self ._feature_info [idx ]['module' ],
438+ type = self ._feature_info [idx ]['hook_type' ]) for idx in out_indices ]
439+ self .feature_hooks = FeatureHooks (hooks , self .named_modules ())
435440
436441 def feature_channels (self , idx = None ):
437442 """ Feature Channel Shortcut
438443 Returns feature channel count for each output index if idx == None. If idx is an integer, will
439444 return feature channel count for that feature block index (independent of out_indices setting).
440445 """
441446 if isinstance (idx , int ):
442- return self .feature_info [idx ]['num_chs' ]
443- return [self .feature_info [i ]['num_chs' ] for i in self .out_indices ]
447+ return self ._feature_info [idx ]['num_chs' ]
448+ return [self ._feature_info [i ]['num_chs' ] for i in self .out_indices ]
449+
450+ def feature_info (self , idx = None ):
451+ """ Feature Channel Shortcut
452+ Returns feature channel count for each output index if idx == None. If idx is an integer, will
453+ return feature channel count for that feature block index (independent of out_indices setting).
454+ """
455+ if isinstance (idx , int ):
456+ return self ._feature_info [idx ]
457+ return [self ._feature_info [i ] for i in self .out_indices ]
444458
445459 def forward (self , x ):
446460 x = self .conv_stem (x )
447461 x = self .bn1 (x )
448462 x = self .act1 (x )
449- self .blocks (x )
450- return self .feature_hooks .get_output (x .device )
463+ if self .feature_hooks is None :
464+ features = []
465+ for i , b in enumerate (self .blocks ):
466+ x = b (x )
467+ if i in self ._stage_to_feature_idx :
468+ features .append (x )
469+ return features
470+ else :
471+ self .blocks (x )
472+ return self .feature_hooks .get_output (x .device )
451473
452474
453475def _create_model (model_kwargs , default_cfg , pretrained = False ):
0 commit comments