 from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
-
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
 
 
 class Attention(nn.Module):
     """ Multi-Head Attention
     """
+
     def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.hidden_dim = hidden_dim
@@ -46,7 +46,7 @@ def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., p
     def forward(self, x):
         B, N, C = x.shape
         qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k = qk.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -62,6 +62,7 @@ def forward(self, x):
 class Block(nn.Module):
     """ TNT Block
     """
+
     def __init__(
             self,
             dim,
@@ -89,7 +90,7 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
         )
-
+
         self.norm_mlp_in = norm_layer(dim)
         self.mlp_in = Mlp(
             in_features=dim,
@@ -118,7 +119,7 @@ def __init__(
             proj_drop=proj_drop,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
+
         self.norm_mlp = norm_layer(dim_out)
         self.mlp = Mlp(
             in_features=dim_out,
@@ -136,13 +137,13 @@ def forward(self, pixel_embed, patch_embed):
         B, N, C = patch_embed.size()
         if self.legacy:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                    self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
             ], dim=1)
         else:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                    self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
             ], dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
@@ -152,7 +153,16 @@ def forward(self, pixel_embed, patch_embed):
 class PixelEmbed(nn.Module):
     """ Image to Pixel Embedding
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
+
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            in_dim=48,
+            stride=4,
+            legacy=False,
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -184,14 +194,17 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
 
     def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
-        _assert(H == self.img_size[0],
+        _assert(
+            H == self.img_size[0],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        _assert(W == self.img_size[1],
+        _assert(
+            W == self.img_size[1],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         if self.legacy:
             x = self.proj(x)
             x = self.unfold(x)
-            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+            x = x.transpose(1, 2).reshape(
+                B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         else:
             x = self.unfold(x)
             x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
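
Quick shape check for the refactored non-legacy path above (an illustrative sketch, not part of the diff; it assumes the default img_size=224, patch_size=16, stride=4, in_dim=48 configuration and a 7x7, padding-3, stride-4 projection conv as built in PixelEmbed.__init__):

import torch
import torch.nn as nn

B, C = 2, 3
x = torch.randn(B, C, 224, 224)
unfold = nn.Unfold(kernel_size=16, stride=16)   # patchify: 14 * 14 = 196 patches
patches = unfold(x)                             # (B, 3*16*16, 196)
patches = patches.transpose(1, 2).reshape(B * 196, C, 16, 16)
proj = nn.Conv2d(C, 48, kernel_size=7, padding=3, stride=4)
pixel_embed = proj(patches)
print(pixel_embed.shape)                        # torch.Size([392, 48, 4, 4]), i.e. new_patch_size = (4, 4)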
@@ -204,6 +217,7 @@ def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
 class TNT(nn.Module):
     """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
     """
+
     def __init__(
             self,
             img_size=224,
@@ -248,7 +262,7 @@ def __init__(
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
         num_pixel = new_patch_size[0] * new_patch_size[1]
-
+
         self.norm1_proj = norm_layer(num_pixel * inner_dim)
         self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
         self.norm2_proj = norm_layer(embed_dim)
@@ -278,7 +292,7 @@ def __init__(
         self.blocks = nn.ModuleList(blocks)
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
-
+
         self.norm = norm_layer(embed_dim)
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -359,7 +373,7 @@ def forward_intermediates(
         B, _, height, width = x.shape
 
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
-
+
         patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
         patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
         patch_embed = patch_embed + self.patch_pos
@@ -381,7 +395,7 @@ def forward_intermediates(
         # split prefix (e.g. class, distill) and spatial feature tokens
         prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
         intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
-
+
         if reshape:
             # reshape to BCHW output format
             H, W = self.pixel_embed.dynamic_feat_size((height, width))
@@ -416,7 +430,7 @@ def prune_intermediate_layers(
     def forward_features(self, x):
         B = x.shape[0]
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
-
+
         patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
         patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
         patch_embed = patch_embed + self.patch_pos
@@ -458,42 +472,47 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
+    'tnt_s_legacy_patch16_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+    ),
     'tnt_s_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
     ),
     'tnt_b_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+        hf_hub_id='timm/',
+        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
     ),
 })
 
 
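Note on the pretrained cfg changes above (editorial comment, not part of the diff): with hf_hub_id='timm/', timm resolves weights from the Hugging Face Hub under the timm organization using the full model.tag name, instead of the old GitHub release URLs that are now kept only as comments. A minimal loading sketch, assuming the weights have been pushed to the Hub:

import timm

# resolves to the 'timm/tnt_s_patch16_224.in1k' repository on the HF Hub
model = timm.create_model('tnt_s_patch16_224.in1k', pretrained=True)
model = model.eval()
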
 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)
-
-    out_dict = {}
-    for k, v in state_dict.items():
-        k = k.replace('outer_pos', 'patch_pos')
-        k = k.replace('inner_pos', 'pixel_pos')
-        k = k.replace('patch_embed', 'pixel_embed')
-        k = k.replace('proj_norm1', 'norm1_proj')
-        k = k.replace('proj_norm2', 'norm2_proj')
-        k = k.replace('inner_norm1', 'norm_in')
-        k = k.replace('inner_attn', 'attn_in')
-        k = k.replace('inner_norm2', 'norm_mlp_in')
-        k = k.replace('inner_mlp', 'mlp_in')
-        k = k.replace('outer_norm1', 'norm_out')
-        k = k.replace('outer_attn', 'attn_out')
-        k = k.replace('outer_norm2', 'norm_mlp')
-        k = k.replace('outer_mlp', 'mlp')
-        if k == 'pixel_pos' and model.pixel_embed.legacy == False:
-            B, N, C = v.shape
-            H = W = int(N ** 0.5)
-            assert H * W == N
-            v = v.permute(0, 2, 1).reshape(B, C, H, W)
-        out_dict[k] = v
+    if 'patch_pos' in state_dict:
+        out_dict = state_dict
+    else:
+        out_dict = {}
+        for k, v in state_dict.items():
+            k = k.replace('outer_pos', 'patch_pos')
+            k = k.replace('inner_pos', 'pixel_pos')
+            k = k.replace('patch_embed', 'pixel_embed')
+            k = k.replace('proj_norm1', 'norm1_proj')
+            k = k.replace('proj_norm2', 'norm2_proj')
+            k = k.replace('inner_norm1', 'norm_in')
+            k = k.replace('inner_attn', 'attn_in')
+            k = k.replace('inner_norm2', 'norm_mlp_in')
+            k = k.replace('inner_mlp', 'mlp_in')
+            k = k.replace('outer_norm1', 'norm_out')
+            k = k.replace('outer_attn', 'attn_out')
+            k = k.replace('outer_norm2', 'norm_mlp')
+            k = k.replace('outer_mlp', 'mlp')
+            if k == 'pixel_pos' and model.pixel_embed.legacy == False:
+                B, N, C = v.shape
+                H = W = int(N ** 0.5)
+                assert H * W == N
+                v = v.permute(0, 2, 1).reshape(B, C, H, W)
+            out_dict[k] = v
 
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if out_dict['patch_pos'].shape != model.patch_pos.shape:
@@ -515,6 +534,15 @@ def _create_tnt(variant, pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
+    model_cfg = dict(
+        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
+        qkv_bias=False, legacy=True)
+    model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
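Usage sketch for the new entrypoint above (not part of the diff; it assumes the default 1000-class head, and pretrained=True would additionally require the tnt_s_legacy_patch16_224.in1k weights to actually be available on the Hub):

import torch
import timm

# builds the pre-refactor TNT-S architecture (legacy=True blocks and pixel embed)
model = timm.create_model('tnt_s_legacy_patch16_224', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
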
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(