
The official MindSpore code is released and available at
https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
+
+The official PyTorch code is released and available at
+https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
811"""
912import math
1013from typing import Optional
1114
1215import torch
1316import torch .nn as nn
1417
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint
from ._registry import register_model

__all__ = ['TNT']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'tnt_s_patch16_224': _cfg(
-        url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
-    'tnt_b_patch16_224': _cfg(
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
-}
-
-
class Attention(nn.Module):
    """ Multi-Head Attention
    """
@@ -94,6 +75,7 @@ def __init__(
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
+            legacy=False,
    ):
        super().__init__()
        # Inner transformer
@@ -115,9 +97,16 @@ def __init__(
            act_layer=act_layer,
            drop=proj_drop,
        )
-
-        self.norm1_proj = norm_layer(dim)
-        self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
+        self.legacy = legacy
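+        # legacy=True keeps the old timm projection (norm over dim, biased Linear);
+        # the non-legacy layout (norm over flattened pixels, bias-free Linear, post-norm)
+        # is what the converted official checkpoints expect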
+        if self.legacy:
+            self.norm1_proj = norm_layer(dim)
+            self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
+        else:
+            self.norm1_proj = norm_layer(dim * num_pixel)
+            self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
+            self.norm2_proj = norm_layer(dim_out)

        # Outer transformer
        self.norm_out = norm_layer(dim_out)
@@ -146,9 +133,17 @@ def forward(self, pixel_embed, patch_embed):
        pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
        # outer
        B, N, C = patch_embed.size()
-        patch_embed = torch.cat(
-            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
-            dim=1)
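+        # merge pixel embeddings into the patch tokens, leaving the class token at index 0 untouched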
+        if self.legacy:
+            patch_embed = torch.cat([
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
+            ], dim=1)
+        else:
+            patch_embed = torch.cat([
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
+            ], dim=1)
        patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
        patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
        return pixel_embed, patch_embed
@@ -157,31 +151,45 @@ def forward(self, pixel_embed, patch_embed):
class PixelEmbed(nn.Module):
    """ Image to Pixel Embedding
    """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # grid_size property necessary for resizing positional embedding
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        num_patches = self.grid_size[0] * self.grid_size[1]
        self.img_size = img_size
+        self.patch_size = patch_size
+        self.legacy = legacy
        self.num_patches = num_patches
        self.in_dim = in_dim
        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
-        self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
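+        # legacy: unfold the conv features into stride-reduced patches;
+        # new: unfold raw pixels per patch, the conv then runs per patch in forward()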
+        if self.legacy:
+            self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
+        else:
+            self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        x = self.proj(x)
-        x = self.unfold(x)
-        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
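+        # both orderings end at shape (B * num_patches, in_dim, *new_patch_size):
+        # conv-then-unfold (legacy) vs unfold-then-conv (matches the official weights)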
+        if self.legacy:
+            x = self.proj(x)
+            x = self.unfold(x)
+            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+        else:
+            x = self.unfold(x)
+            x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
+            x = self.proj(x)
        x = x + pixel_pos
        x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
        return x
@@ -211,6 +215,7 @@ def __init__(
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            first_stride=4,
+            legacy=False,
    ):
        super().__init__()
        assert global_pool in ('', 'token', 'avg')
@@ -225,6 +230,7 @@ def __init__(
            in_chans=in_chans,
            in_dim=inner_dim,
            stride=first_stride,
+            legacy=legacy,
        )
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
@@ -255,6 +261,7 @@ def __init__(
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
+                legacy=legacy,
            ))
        self.blocks = nn.ModuleList(blocks)
        self.norm = norm_layer(embed_dim)
@@ -338,14 +345,43 @@ def forward(self, x):


def checkpoint_filter_fn(state_dict, model):
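+    # the official checkpoints carry an 'outer_tokens' buffer with no counterpart
+    # in this implementation; discard it before remapping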
+    state_dict.pop('outer_tokens', None)
+
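+    # remap official Efficient-AI-Backbones parameter names to timm's module names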
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = k.replace('outer_pos', 'patch_pos')
+        k = k.replace('inner_pos', 'pixel_pos')
+        k = k.replace('patch_embed', 'pixel_embed')
+        k = k.replace('proj_norm1', 'norm1_proj')
+        k = k.replace('proj_norm2', 'norm2_proj')
+        k = k.replace('inner_norm1', 'norm_in')
+        k = k.replace('inner_attn', 'attn_in')
+        k = k.replace('inner_norm2', 'norm_mlp_in')
+        k = k.replace('inner_mlp', 'mlp_in')
+        k = k.replace('outer_norm1', 'norm_out')
+        k = k.replace('outer_attn', 'attn_out')
+        k = k.replace('outer_norm2', 'norm_mlp')
+        k = k.replace('outer_mlp', 'mlp')
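+        # the official checkpoints store pixel_pos as (1, N, C); this model
+        # registers it as a (1, C, H, W) grid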
+        if k == 'pixel_pos':
+            B, N, C = v.shape
+            H = W = int(N ** 0.5)
+            assert H * W == N
+            v = v.permute(0, 2, 1).reshape(B, C, H, W)
+        out_dict[k] = v
+
341372 """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    if state_dict['patch_pos'].shape != model.patch_pos.shape:
-        state_dict['patch_pos'] = resample_abs_pos_embed(
-            state_dict['patch_pos'],
+    if out_dict['patch_pos'].shape != model.patch_pos.shape:
+        out_dict['patch_pos'] = resample_abs_pos_embed(
+            out_dict['patch_pos'],
            new_size=model.pixel_embed.grid_size,
            num_prefix_tokens=1,
        )
-    return state_dict
+    return out_dict


def _create_tnt(variant, pretrained=False, **kwargs):
@@ -359,6 +390,32 @@ def _create_tnt(variant, pretrained=False, **kwargs):
    return model


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
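+        # IMAGENET_INCEPTION_MEAN/STD are (0.5, 0.5, 0.5), the values the removed
+        # per-model mean/std overrides hard-coded before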
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'tnt_s_patch16_224': _cfg(
+        # hf_hub_id='timm/',
+        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+    ),
+    'tnt_b_patch16_224': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+    ),
+}
+
+
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
    model_cfg = dict(