Commit b37f0f7

Update tnt.py
1 parent c8c4f25 commit b37f0f7


timm/models/tnt.py

Lines changed: 93 additions & 38 deletions
@@ -5,14 +5,17 @@

 The official mindspore code is released and available at
 https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
+
+The official pytorch code is released and available at
+https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch
 """
 import math
 from typing import Optional

 import torch
 import torch.nn as nn

-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple, resample_abs_pos_embed
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint
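
The import swap pairs with the `_cfg` move at the bottom of the diff: both TNT variants normalize with 0.5 mean/std, which is exactly what timm's Inception constants hold, so the per-model `mean`/`std` overrides deleted in the next hunk become redundant. A quick sanity check of the constants (values as published in timm.data):

import timm.data

# Both resolve to 0.5 per channel, matching the explicit overrides the old
# default_cfgs carried for tnt_s_patch16_224 and tnt_b_patch16_224.
assert timm.data.IMAGENET_INCEPTION_MEAN == (0.5, 0.5, 0.5)
assert timm.data.IMAGENET_INCEPTION_STD == (0.5, 0.5, 0.5)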
@@ -22,28 +25,6 @@

 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this


-def _cfg(url='', **kwargs):
-    return {
-        'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
-        **kwargs
-    }
-
-
-default_cfgs = {
-    'tnt_s_patch16_224': _cfg(
-        url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
-    'tnt_b_patch16_224': _cfg(
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
-}
-
-
 class Attention(nn.Module):
     """ Multi-Head Attention
     """
@@ -94,6 +75,7 @@ def __init__(
             drop_path=0.,
             act_layer=nn.GELU,
             norm_layer=nn.LayerNorm,
+            legacy=False,
     ):
         super().__init__()
         # Inner transformer
@@ -115,9 +97,14 @@ def __init__(
             act_layer=act_layer,
             drop=proj_drop,
         )
-
-        self.norm1_proj = norm_layer(dim)
-        self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
+        self.legacy = legacy
+        if self.legacy:
+            self.norm1_proj = norm_layer(dim)
+            self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
+        else:
+            self.norm1_proj = norm_layer(dim * num_pixel)
+            self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
+            self.norm2_proj = norm_layer(dim_out)

         # Outer transformer
         self.norm_out = norm_layer(dim_out)
@@ -146,9 +133,16 @@ def forward(self, pixel_embed, patch_embed):
         pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
         # outer
         B, N, C = patch_embed.size()
-        patch_embed = torch.cat(
-            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
-            dim=1)
+        if self.legacy:
+            patch_embed = torch.cat([
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
+            ], dim=1)
+        else:
+            patch_embed = torch.cat([
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
+            ], dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
         return pixel_embed, patch_embed
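
Both branches of Block.forward fuse the inner (pixel) tokens back into the outer (patch) tokens; they differ only in where normalization sits relative to the flatten and linear projection. The legacy path normalizes each pixel token before flattening, while the new path normalizes the flattened vector and adds a post-projection norm, matching the layout of the official tnt_pytorch weights. A minimal shape sketch of the non-legacy path, with illustrative sizes that are not taken from the diff:

import torch
import torch.nn as nn

B, N, num_pixel, dim, dim_out = 2, 197, 16, 24, 384  # hypothetical sizes
norm1_proj = nn.LayerNorm(dim * num_pixel)
proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
norm2_proj = nn.LayerNorm(dim_out)

pixel_embed = torch.randn(B * (N - 1), num_pixel, dim)  # inner tokens
flat = pixel_embed.reshape(B, N - 1, -1)                # (B, N-1, dim*num_pixel)
fused = norm2_proj(proj(norm1_proj(flat)))              # (B, N-1, dim_out)
print(fused.shape)  # torch.Size([2, 196, 384]); added to patch tokens 1..N-1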
@@ -157,31 +151,41 @@ def forward(self, pixel_embed, patch_embed):
 class PixelEmbed(nn.Module):
     """ Image to Pixel Embedding
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
         # grid_size property necessary for resizing positional embedding
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         num_patches = (self.grid_size[0]) * (self.grid_size[1])
         self.img_size = img_size
+        self.patch_size = patch_size
+        self.legacy = legacy
         self.num_patches = num_patches
         self.in_dim = in_dim
         new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
         self.new_patch_size = new_patch_size

         self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
-        self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
+        if self.legacy:
+            self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
+        else:
+            self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)

     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
         _assert(H == self.img_size[0],
                 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         _assert(W == self.img_size[1],
                 f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        x = self.proj(x)
-        x = self.unfold(x)
-        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+        if self.legacy:
+            x = self.proj(x)
+            x = self.unfold(x)
+            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+        else:
+            x = self.unfold(x)
+            x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
+            x = self.proj(x)
         x = x + pixel_pos
         x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
         return x
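
The non-legacy PixelEmbed reverses the order of operations: the raw image is unfolded into one patch_size x patch_size tile per patch first, and the stride-4 conv then runs per tile, which still yields a 4x4 grid of in_dim-channel pixel tokens per patch. A standalone check of the unfold arithmetic, assuming the default 224x224 input and 16-pixel patches:

import torch
import torch.nn as nn

B, C, H, W, p = 2, 3, 224, 224, 16
unfold = nn.Unfold(kernel_size=p, stride=p)

x = unfold(torch.randn(B, C, H, W))              # (B, C*p*p, L), L = (H//p) * (W//p) = 196
x = x.transpose(1, 2).reshape(B * 196, C, p, p)  # one (C, p, p) tile per patch

proj = nn.Conv2d(C, 24, kernel_size=7, padding=3, stride=4)  # in_dim=24, as in tnt_s
print(proj(x).shape)  # torch.Size([392, 24, 4, 4])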
@@ -211,6 +215,7 @@ def __init__(
             drop_path_rate=0.,
             norm_layer=nn.LayerNorm,
             first_stride=4,
+            legacy=False,
     ):
         super().__init__()
         assert global_pool in ('', 'token', 'avg')
@@ -225,6 +230,7 @@ def __init__(
             in_chans=in_chans,
             in_dim=inner_dim,
             stride=first_stride,
+            legacy=legacy,
         )
         num_patches = self.pixel_embed.num_patches
         self.num_patches = num_patches
@@ -255,6 +261,7 @@ def __init__(
                 attn_drop=attn_drop_rate,
                 drop_path=dpr[i],
                 norm_layer=norm_layer,
+                legacy=legacy,
             ))
         self.blocks = nn.ModuleList(blocks)
         self.norm = norm_layer(embed_dim)
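
These three small hunks only thread the new flag from the TNT constructor down into PixelEmbed and each Block. Since create_model forwards unrecognized kwargs to the model class, the pre-commit module layout should remain reachable; a hedged sketch, assuming kwargs pass through unchanged:

import timm

# Hypothetical: request the old (legacy) projection/unfold layout explicitly.
model = timm.create_model('tnt_s_patch16_224', pretrained=False, legacy=True)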
@@ -338,14 +345,38 @@ def forward(self, x):


 def checkpoint_filter_fn(state_dict, model):
+    state_dict.pop('outer_tokens', None)
+
+    out_dict = {}
+    for k, v in state_dict.items():
+        k = k.replace('outer_pos', 'patch_pos')
+        k = k.replace('inner_pos', 'pixel_pos')
+        k = k.replace('patch_embed', 'pixel_embed')
+        k = k.replace('proj_norm1', 'norm1_proj')
+        k = k.replace('proj_norm2', 'norm2_proj')
+        k = k.replace('inner_norm1', 'norm_in')
+        k = k.replace('inner_attn', 'attn_in')
+        k = k.replace('inner_norm2', 'norm_mlp_in')
+        k = k.replace('inner_mlp', 'mlp_in')
+        k = k.replace('outer_norm1', 'norm_out')
+        k = k.replace('outer_attn', 'attn_out')
+        k = k.replace('outer_norm2', 'norm_mlp')
+        k = k.replace('outer_mlp', 'mlp')
+        if k == 'pixel_pos':
+            B, N, C = v.shape
+            H = W = int(N ** 0.5)
+            assert H * W == N
+            v = v.permute(0, 2, 1).reshape(B, C, H, W)
+        out_dict[k] = v
+
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    if state_dict['patch_pos'].shape != model.patch_pos.shape:
-        state_dict['patch_pos'] = resample_abs_pos_embed(
-            state_dict['patch_pos'],
+    if out_dict['patch_pos'].shape != model.patch_pos.shape:
+        out_dict['patch_pos'] = resample_abs_pos_embed(
+            out_dict['patch_pos'],
             new_size=model.pixel_embed.grid_size,
             num_prefix_tokens=1,
         )
-    return state_dict
+    return out_dict


 def _create_tnt(variant, pretrained=False, **kwargs):
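
checkpoint_filter_fn now translates the official tnt_pytorch parameter names into timm's module names and reshapes the inner positional embedding into the conv-style layout PixelEmbed expects. For example, 'blocks.0.inner_attn.qkv.weight' becomes 'blocks.0.attn_in.qkv.weight', and a hypothetical 'inner_pos' tensor is rearranged like this:

import torch

v = torch.zeros(1, 16, 24)                  # hypothetical 'inner_pos': (B, N, C), N = 4*4 pixel positions
B, N, C = v.shape
H = W = int(N ** 0.5)                       # 4
v = v.permute(0, 2, 1).reshape(B, C, H, W)  # channels-first grid
print(v.shape)                              # torch.Size([1, 24, 4, 4]) -> 'pixel_pos'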
@@ -359,6 +390,30 @@ def _create_tnt(variant, pretrained=False, **kwargs):
     return model


+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'tnt_s_patch16_224': _cfg(
+        # hf_hub_id='timm/',
+        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+    ),
+    'tnt_b_patch16_224': _cfg(
+        # hf_hub_id='timm/',
+        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+    ),
+}
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(
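
With the cfgs now pointing at the official Efficient-AI-Backbones releases, pretrained weights load through the filter above. A minimal usage sketch (download access assumed):

import timm
import torch

model = timm.create_model('tnt_s_patch16_224', pretrained=True).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])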
