4040"""
4141import math
4242from functools import partial
43+ from typing import List , Optional , Union , Tuple
4344
4445import torch
4546import torch .nn as nn
4647
4748from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
4849from timm .layers import PatchEmbed , Mlp , GluMlp , GatedMlp , DropPath , lecun_normal_ , to_2tuple
4950from ._builder import build_model_with_cfg
51+ from ._features import feature_take_indices
5052from ._manipulate import named_apply , checkpoint_seq
5153from ._registry import generate_default_cfgs , register_model , register_model_deprecations
5254
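The newly imported `feature_take_indices` helper resolves the `indices` argument used throughout the additions below. A minimal sketch of its behaviour, assuming the semantics documented in `forward_intermediates` (an int selects the last n blocks, a sequence is taken as given); the printed values are what I would expect, not captured output:

import timm  # noqa: F401  (ensures the models package is importable)
from timm.models._features import feature_take_indices

# Returns (take_indices, max_index) for a model with 12 blocks.
print(feature_take_indices(12, 3))       # expected: ([9, 10, 11], 11) - last 3 blocks
print(feature_take_indices(12, [2, 5]))  # expected: ([2, 5], 5) - explicit indices
print(feature_take_indices(12, None))    # expected: all 12 indices, max_index 11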
@@ -211,6 +213,7 @@ def __init__(
            embed_dim=embed_dim,
            norm_layer=norm_layer if stem_norm else None,
        )
+        reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
        # FIXME drop_path (stochastic depth scaling rule or all the same?)
        self.blocks = nn.Sequential(*[
            block_layer(
@@ -224,6 +227,8 @@ def __init__(
                drop_path=drop_path_rate,
            )
            for _ in range(num_blocks)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
        self.norm = norm_layer(embed_dim)
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
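For reference, with the base/16 defaults (12 blocks, embed_dim=768, patch_size=16) the new `feature_info` list would be populated roughly as below; the values are reconstructed from those defaults rather than taken from a run:

# Illustrative reconstruction of feature_info for a 12-block, embed_dim=768,
# patch_size=16 Mixer; reduction falls back to patch_size when the stem has
# no feat_ratio() method.
feature_info = [
    dict(module=f'blocks.{i}', num_chs=768, reduction=16)
    for i in range(12)
]
print(feature_info[0])   # {'module': 'blocks.0', 'num_chs': 768, 'reduction': 16}
print(feature_info[11])  # {'module': 'blocks.11', 'num_chs': 768, 'reduction': 16}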
@@ -257,6 +262,76 @@ def reset_classifier(self, num_classes, global_pool=None):
        self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
274+ """ Forward features that returns intermediates.
275+
276+ Args:
277+ x: Input image tensor
278+ indices: Take last n blocks if int, all if None, select matching indices if sequence
279+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
280+ norm: Apply norm layer to all intermediates
281+ stop_early: Stop iterating over blocks when last desired intermediate hit
282+ output_fmt: Shape of intermediate feature outputs
283+ intermediates_only: Only return intermediate features
284+ Returns:
285+
286+ """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.stem(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.stem.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
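A minimal usage sketch of the new method, assuming a registered variant such as `mixer_b16_224` (12 blocks, embed_dim 768); the shapes in the comments are what I'd expect for a 224x224 input, not verified output:

import torch
import timm

model = timm.create_model('mixer_b16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Return only the selected intermediates as NCHW feature maps.
feats = model.forward_intermediates(x, indices=[2, 5, 11], intermediates_only=True)
for f in feats:
    print(f.shape)  # expected: torch.Size([1, 768, 14, 14])

# Or keep the final (normed) features alongside the last 3 intermediates.
final, feats = model.forward_intermediates(x, indices=3, output_fmt='NLC')
print(final.shape)  # expected: torch.Size([1, 196, 768])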
@@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):


def _create_mixer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
    model = build_model_with_cfg(
        MlpMixer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )
    return model
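With the `feature_cfg` wired into `build_model_with_cfg`, `features_only=True` should now work for Mixer variants through the getter-based feature wrapper instead of raising. A hedged example (variant name, indices, and the expected values in comments are illustrative):

import torch
import timm

# out_indices selects which blocks are exposed as feature taps (here, the last three).
fx = timm.create_model('mixer_b16_224', pretrained=False, features_only=True, out_indices=(9, 10, 11))
x = torch.randn(1, 3, 224, 224)
feats = fx(x)
print([tuple(f.shape) for f in feats])  # expected: three (1, 768, 14, 14) maps
print(fx.feature_info.channels())       # expected: [768, 768, 768]
print(fx.feature_info.reduction())      # expected: [16, 16, 16]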