 
 Modifications and timm support by / Copyright 2022, Ross Wightman
 """
-from typing import Dict
+from typing import Dict, List, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
@@ -382,16 +383,19 @@ def __init__(
         prev_dim = embed_dims[0]
 
         # stochastic depth decay rule
+        self.num_stages = len(depths)
+        last_stage = self.num_stages - 1
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
-        downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
+        downsamples = downsamples or (False,) + (True,) * (self.num_stages - 1)
         stages = []
-        for i in range(len(depths)):
+        self.feature_info = []
+        for i in range(self.num_stages):
             stage = EfficientFormerStage(
                 prev_dim,
                 embed_dims[i],
                 depths[i],
                 downsample=downsamples[i],
-                num_vit=num_vit if i == 3 else 0,
+                num_vit=num_vit if i == last_stage else 0,
                 pool_size=pool_size,
                 mlp_ratio=mlp_ratios,
                 act_layer=act_layer,
@@ -403,7 +407,7 @@ def __init__(
             )
             prev_dim = embed_dims[i]
             stages.append(stage)
-
+            self.feature_info += [dict(num_chs=embed_dims[i], reduction=2**(1+i), module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         # Classifier head
@@ -456,6 +460,76 @@ def reset_classifier(self, num_classes, global_pool=None):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Union[int, List[int], Tuple[int]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        B, C, H, W = x.shape
+
+        last_idx = self.num_stages - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+        feat_idx = 0
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx < last_idx:
+                B, C, H, W = x.shape
+            if feat_idx in take_indices:
+                if feat_idx == last_idx:
+                    x_inter = self.norm(x) if norm else x
+                    intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2))
+                else:
+                    intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
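
A minimal usage sketch of the new `forward_intermediates` / `prune_intermediate_layers` API, assuming this change is applied; `'efficientformer_l1'` is used purely as an illustrative variant name:

```python
import torch
import timm

# Build a model without pretrained weights for a quick shape check.
model = timm.create_model('efficientformer_l1', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Collect NCHW feature maps from the last two stages only.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
for f in feats:
    print(f.shape)

# Optionally drop the classifier head and any stages past the last
# requested intermediate before export or fine-tuning.
model.prune_intermediate_layers(indices=2, prune_head=True)
```
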
@@ -534,13 +608,13 @@ def _cfg(url='', **kwargs):
 
 
 def _create_efficientformer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for EfficientFormer models.')
-
+    out_indices = kwargs.pop('out_indices', 4)
     model = build_model_with_cfg(
         EfficientFormer, variant, pretrained,
         pretrained_filter_fn=_checkpoint_filter_fn,
-        **kwargs)
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
+        **kwargs,
+    )
     return model
 
 
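
With the `features_only` guard removed, standard timm feature extraction should work through the `'getter'` feature wrapper. A small sketch, again assuming this change is applied and using `'efficientformer_l1'` as an illustrative variant:

```python
import torch
import timm

# features_only routes through the feature_cfg / out_indices set up above.
fx = timm.create_model('efficientformer_l1', features_only=True, out_indices=(0, 1, 2, 3))
print(fx.feature_info.channels(), fx.feature_info.reduction())

# Returns a list of NCHW feature maps, one per selected stage.
outs = fx(torch.randn(1, 3, 224, 224))
for o in outs:
    print(o.shape)
```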