 https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
 """
 import math
-from typing import Optional
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint
 from ._registry import register_model

@@ -172,7 +173,16 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4,
         else:
             self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

-    def forward(self, x, pixel_pos):
+    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
+        if as_scalar:
+            return max(self.patch_size)
+        else:
+            return self.patch_size
+
+    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
+        return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
+
+    def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
         _assert(H == self.img_size[0],
                 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
@@ -222,6 +232,7 @@ def __init__(
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
+        self.num_prefix_tokens = 1
         self.grad_checkpointing = False

         self.pixel_embed = PixelEmbed(
@@ -233,6 +244,7 @@ def __init__(
             legacy=legacy,
         )
         num_patches = self.pixel_embed.num_patches
+        r = self.pixel_embed.feat_ratio() if hasattr(self.pixel_embed, 'feat_ratio') else patch_size
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
         num_pixel = new_patch_size[0] * new_patch_size[1]
@@ -264,8 +276,10 @@ def __init__(
                 legacy=legacy,
             ))
         self.blocks = nn.ModuleList(blocks)
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
+
         self.norm = norm_layer(embed_dim)
-
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

@@ -313,6 +327,92 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            return_prefix_tokens: bool = False,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if an int, select by matching indices if a sequence
+            return_prefix_tokens: Return both prefix and spatial intermediate tokens
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NLC')
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+
+        pixel_embed = self.pixel_embed(x, self.pixel_pos)
+
+        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
+        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
+        patch_embed = patch_embed + self.patch_pos
+        patch_embed = self.pos_drop(patch_embed)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+
+        for i, blk in enumerate(blocks):
+            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(patch_embed) if norm else patch_embed)
+
+        # process intermediates
+        if self.num_prefix_tokens:
+            # split prefix (e.g. class, distill) and spatial feature tokens
+            prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
+            intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
+
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.pixel_embed.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+        if not torch.jit.is_scripting() and return_prefix_tokens:
+            # return_prefix_tokens is not supported in torchscript due to poor type handling
+            intermediates = list(zip(intermediates, prefix_tokens))
+
+        if intermediates_only:
+            return intermediates
+
+        patch_embed = self.norm(patch_embed)
+
+        return patch_embed, intermediates
+
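+    # Usage sketch (illustrative, not prescribed by this module): for a 224x224 input with
+    # the default 16x16 outer patches (a 14x14 token grid), something like
+    #   final, mids = model.forward_intermediates(x, indices=[9, 11], output_fmt='NCHW')
+    # yields `final` as (B, 197, embed_dim) tokens and each entry of `mids` as a
+    # (B, embed_dim, 14, 14) feature map (the class token is split off before reshape).
+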
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
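+    # Illustrative sketch (assumed call pattern): keeping only the blocks needed up to
+    # index 9 and dropping the classifier,
+    #   model.prune_intermediate_layers(indices=[9], prune_head=True)
+    # leaves len(model.blocks) == 10 and replaces the head with nn.Identity().
+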
     def forward_features(self, x):
         B = x.shape[0]
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
@@ -322,19 +422,18 @@ def forward_features(self, x):
         patch_embed = patch_embed + self.patch_pos
         patch_embed = self.pos_drop(patch_embed)

-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            for blk in self.blocks:
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
                 pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
-        else:
-            for blk in self.blocks:
+            else:
                 pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

         patch_embed = self.norm(patch_embed)
         return patch_embed

     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)

@@ -344,6 +443,30 @@ def forward(self, x):
         return x


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'tnt_s_patch16_224.in1k': _cfg(
+        # hf_hub_id='timm/',
+        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+    ),
+    'tnt_b_patch16_224.in1k': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+    ),
+}
+
+
 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)

@@ -380,40 +503,15 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_tnt(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for Vision Transformer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         TNT, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs)
     return model

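+# Illustrative sketch of the resulting features_only path (assumes the standard
+# timm.create_model API; spatial size follows from feature_info / reduction above):
+#
+#   m = timm.create_model('tnt_s_patch16_224', features_only=True, out_indices=(9, 11))
+#   feats = m(torch.randn(1, 3, 224, 224))  # list of NCHW maps, (1, embed_dim, 14, 14) each
+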
-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'tnt_s_patch16_224': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
-    ),
-    'tnt_b_patch16_224': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
-    ),
-}
-
-
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(