# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
+from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, use_fused_attn
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

@@ -246,8 +248,8 @@ def __init__(
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
-
        num_patches = self.patch_embed.num_patches
+        r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
@@ -268,6 +270,7 @@ def __init__(
            mlp_block=mlp_block,
            init_values=init_values,
        ) for i in range(depth)])
+        self.feature_info = [dict(num_chs=embed_dim, reduction=r, module=f'blocks.{i}') for i in range(depth)]

        self.blocks_token_only = nn.ModuleList([block_layers_token(
            dim=embed_dim,
@@ -283,7 +286,6 @@ def __init__(

        self.norm = norm_layer(embed_dim)

-        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

@@ -336,6 +338,80 @@ def reset_classifier(self, num_classes, global_pool=None):
        self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = True,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format for ViT features must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.patch_embed(x)
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.patch_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        # NOTE not supporting return of class tokens
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+        for i, blk in enumerate(self.blocks_token_only):
+            cls_tokens = blk(x, cls_tokens)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            n: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.blocks_token_only = nn.ModuleList()  # prune token blocks with head
+            self.head = nn.Identity()
+        return take_indices
+
    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
@@ -373,14 +449,13 @@ def checkpoint_filter_fn(state_dict, model=None):


def _create_cait(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
    model = build_model_with_cfg(
        Cait,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )
    return model
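
Usage sketch (not part of the commit): a minimal example of exercising the new intermediate-feature paths through timm's create_model factory. The variant name cait_xxs24_224 and the index choices below are illustrative assumptions, not taken from this diff.

import torch
import timm

# assumed registered variant; any CaiT variant should work the same way
model = timm.create_model('cait_xxs24_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# direct call: final (normed, cls-token-prepended) tokens plus the last two
# block outputs reshaped to NCHW feature maps
final, intermediates = model.forward_intermediates(x, indices=2)

# or via features_only, which now routes through the 'getter' feature wrapper
# enabled by the feature_cfg added in _create_cait (out_indices=3 -> last 3 blocks)
feat_model = timm.create_model('cait_xxs24_224', pretrained=False, features_only=True, out_indices=3)
features = feat_model(x)  # list of 3 NCHW feature tensors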