1414from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
1515from timm .models .helpers import load_pretrained
1616from timm .models .layers import Mlp , DropPath , trunc_normal_
17+ from timm .models .layers .helpers import to_2tuple
1718from timm .models .registry import register_model
19+ from timm .models .vision_transformer import resize_pos_embed
1820
1921
2022def _cfg (url = '' , ** kwargs ):
@@ -118,23 +120,27 @@ class PixelEmbed(nn.Module):
118120 """
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
    """Set up the pixel-level embedding for (possibly non-square) images.

    Args:
        img_size: input image size; passed through ``to_2tuple`` and then
            indexed as ``(height, width)``.
        patch_size: outer patch size; passed through ``to_2tuple`` and then
            indexed as ``(height, width)``.
        in_chans: number of input channels fed to the projection conv.
        in_dim: number of output channels of the projection conv.
        stride: stride of the projection conv; also divides the patch size
            to obtain the per-patch pixel grid.
    """
    super().__init__()
    # NOTE(review): to_2tuple is presumed to turn a scalar into a pair and
    # leave an existing pair untouched -- confirm against timm's helper.
    img_size = to_2tuple(img_size)
    patch_size = to_2tuple(patch_size)

    # grid_size is read by checkpoint_filter_fn when the patch position
    # embedding of a checkpoint has to be resized.
    self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
    self.num_patches = self.grid_size[0] * self.grid_size[1]
    self.img_size = img_size
    self.in_dim = in_dim

    # Size of one patch after the strided projection, per spatial axis.
    new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
    self.new_patch_size = new_patch_size

    self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
    self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
130136
def forward(self, x, pixel_pos):
    """Project an image batch and split it into per-patch pixel sequences.

    Args:
        x: input tensor unpacked as ``(B, C, H, W)``; ``H`` and ``W`` must
            equal ``self.img_size[0]`` and ``self.img_size[1]``.
        pixel_pos: pixel position embedding added to the patch tensor; it
            must broadcast against
            ``(B * num_patches, in_dim, new_patch_size[0], new_patch_size[1])``.

    Returns:
        Tensor reshaped to ``(B * num_patches, -1, in_dim)``, i.e. one row
        per pixel of each patch with ``in_dim`` features.

    Raises:
        AssertionError: if the spatial size of ``x`` differs from the size
            the module was built for.
    """
    B, C, H, W = x.shape
    assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
    x = self.proj(x)
    x = self.unfold(x)
    # Unfold yields (B, in_dim * ph * pw, num_patches); move the patch axis
    # forward and fold it into the batch axis so each patch is one sample.
    x = x.transpose(1, 2).reshape(
        B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
    x = x + pixel_pos
    # Flatten the per-patch spatial grid and put channels last.
    x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
    return x
@@ -155,15 +161,15 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
155161 num_patches = self .pixel_embed .num_patches
156162 self .num_patches = num_patches
157163 new_patch_size = self .pixel_embed .new_patch_size
158- num_pixel = new_patch_size ** 2
164+ num_pixel = new_patch_size [ 0 ] * new_patch_size [ 1 ]
159165
160166 self .norm1_proj = norm_layer (num_pixel * in_dim )
161167 self .proj = nn .Linear (num_pixel * in_dim , embed_dim )
162168 self .norm2_proj = norm_layer (embed_dim )
163169
164170 self .cls_token = nn .Parameter (torch .zeros (1 , 1 , embed_dim ))
165171 self .patch_pos = nn .Parameter (torch .zeros (1 , num_patches + 1 , embed_dim ))
166- self .pixel_pos = nn .Parameter (torch .zeros (1 , in_dim , new_patch_size , new_patch_size ))
172+ self .pixel_pos = nn .Parameter (torch .zeros (1 , in_dim , new_patch_size [ 0 ] , new_patch_size [ 1 ] ))
167173 self .pos_drop = nn .Dropout (p = drop_rate )
168174
169175 dpr = [x .item () for x in torch .linspace (0 , drop_path_rate , depth )] # stochastic depth decay rule
@@ -224,14 +230,23 @@ def forward(self, x):
224230 return x
225231
226232
def checkpoint_filter_fn(state_dict, model):
    """Resize the checkpoint's patch position embedding to fit ``model``.

    When ``state_dict['patch_pos']`` has a different shape from
    ``model.patch_pos`` it is replaced by the result of ``resize_pos_embed``,
    called with the model's embedding, its ``num_tokens`` attribute
    (defaulting to 1 when absent) and ``model.pixel_embed.grid_size``.
    All other entries are left untouched.

    Args:
        state_dict: checkpoint mapping; must contain the key ``'patch_pos'``.
        model: model instance exposing ``patch_pos`` and
            ``pixel_embed.grid_size``.

    Returns:
        The same ``state_dict`` object, modified in place when a resize was
        needed.
    """
    target = model.patch_pos
    if state_dict['patch_pos'].shape != target.shape:
        state_dict['patch_pos'] = resize_pos_embed(
            state_dict['patch_pos'],
            target,
            getattr(model, 'num_tokens', 1),
            model.pixel_embed.grid_size,
        )
    return state_dict
239+
240+
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs):
    """Build the TNT-S model (patch size 16) and optionally load weights.

    Args:
        pretrained: when true, pretrained weights are loaded through
            ``load_pretrained``; ``checkpoint_filter_fn`` is passed as
            ``filter_fn`` so the checkpoint can be adapted before loading.
        **kwargs: forwarded to the ``TNT`` constructor; ``in_chans`` (default
            3) is also forwarded to ``load_pretrained``.

    Returns:
        The constructed ``TNT`` model with ``default_cfg`` set.
    """
    model = TNT(
        patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
        qkv_bias=False, **kwargs)
    model.default_cfg = default_cfgs['tnt_s_patch16_224']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
            filter_fn=checkpoint_filter_fn)
    return model
236251
237252
0 commit comments