 from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
@@ -948,25 +949,37 @@ def __init__(
         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
         prev_chs = in_chs
         curr_stride = 1
+        last_feat_idx = -1
         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
             layer_fn = layers.conv_norm_act if na else create_conv2d
             conv_name = f'conv{i + 1}'
             if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
             prev_chs = ch
             curr_stride *= s
             prev_feat = conv_name
 
         if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module('pool', nn.MaxPool2d(3, 2, 1))
             curr_stride *= 2
             prev_feat = 'pool'
 
-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
         assert curr_stride == stride
 
+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate
+
 
 def create_byob_stem(
         in_chs: int,
@@ -1008,7 +1021,7 @@ def create_byob_stem(
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
     else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
     return stem, feature_info
 
 
@@ -1122,7 +1135,7 @@ def create_byob_stages(
             feat_size = reduce_feat_size(feat_size, stride)
 
         stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)
 
     feature_info.append(prev_feat)
     return nn.Sequential(*stages), feature_info
@@ -1198,6 +1211,7 @@ def __init__(
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']
 
         prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
@@ -1207,7 +1221,8 @@ def __init__(
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]
 
         self.head = ClassifierHead(
             self.num_features,
@@ -1241,6 +1256,83 @@ def get_classifier(self):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
1269+ """ Forward features that returns intermediates.
1270+
1271+ Args:
1272+ x: Input image tensor
1273+ indices: Take last n blocks if int, all if None, select matching indices if sequence
1274+ norm: Apply norm layer to compatible intermediates
1275+ stop_early: Stop iterating over blocks when last desired intermediate hit
1276+ output_fmt: Shape of intermediate feature outputs
1277+ intermediates_only: Only return intermediate features
1278+ exclude_final_conv: Exclude final_conv from last intermediate
1279+ Returns:
1280+
1281+ """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
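
Below is a minimal usage sketch of the intermediate-feature API added in this commit (not part of the diff itself). It assumes a timm build that already contains this change; 'resnet51q' is just an illustrative BYOBNet model name, and only the forward_intermediates / prune_intermediate_layers signatures are taken from the code above.

import torch
import timm

# 'resnet51q' is one illustrative byobnet.py config; any BYOBNet/ByoaNet model works the same way.
model = timm.create_model('resnet51q', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# An int means "take the last n" feature maps; here the last two, plus the final (post final_conv) output.
final, intermediates = model.forward_intermediates(x, indices=2)
print(final.shape, [t.shape for t in intermediates])

# A sequence selects specific feature indices; intermediates_only=True returns just the list.
feats = model.forward_intermediates(x, indices=[0, 1, 2], intermediates_only=True)

# Truncate stages, final_conv and head to only what the first three feature maps need.
kept = model.prune_intermediate_layers(indices=[0, 1, 2], prune_head=True)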