
Commit 74ad32a

Updated tnt model weights on hub; added back the legacy model in case backwards compat is needed
1 parent 69b1fbc commit 74ad32a

File tree

1 file changed: +72 -44 lines changed


timm/models/tnt.py

Lines changed: 72 additions & 44 deletions
@@ -22,13 +22,13 @@
 from ._manipulate import checkpoint
 from ._registry import generate_default_cfgs, register_model
 
-
 __all__ = ['TNT']  # model_registry will add each entrypoint fn to this
 
 
 class Attention(nn.Module):
     """ Multi-Head Attention
     """
+
     def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.hidden_dim = hidden_dim
@@ -46,7 +46,7 @@ def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., p
     def forward(self, x):
         B, N, C = x.shape
         qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+        q, k = qk.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -62,6 +62,7 @@ def forward(self, x):
 class Block(nn.Module):
     """ TNT Block
     """
+
     def __init__(
             self,
             dim,
@@ -89,7 +90,7 @@ def __init__(
             attn_drop=attn_drop,
             proj_drop=proj_drop,
         )
-
+
         self.norm_mlp_in = norm_layer(dim)
         self.mlp_in = Mlp(
             in_features=dim,
@@ -118,7 +119,7 @@ def __init__(
             proj_drop=proj_drop,
         )
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-
+
         self.norm_mlp = norm_layer(dim_out)
         self.mlp = Mlp(
             in_features=dim_out,
@@ -136,13 +137,13 @@ def forward(self, pixel_embed, patch_embed):
         B, N, C = patch_embed.size()
         if self.legacy:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                    self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),
             ], dim=1)
         else:
             patch_embed = torch.cat([
-                patch_embed[:, 0:1], patch_embed[:, 1:] + \
-                    self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
+                patch_embed[:, 0:1],
+                patch_embed[:, 1:] + self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, N - 1, -1)))),
             ], dim=1)
         patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
         patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
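Note (not part of the commit): to make the shape flow of the fusion step above easier to follow, here is a small self-contained sketch of the non-legacy branch. The dimensions are illustrative values matching the tnt_s config in this file (inner_dim=24, embed_dim=384, 16 pixels per 16x16 patch, 196 patches plus a class token), and the norm/proj modules are local stand-ins rather than the real Block attributes.

import torch
import torch.nn as nn

# illustrative sizes for a 224x224 input with the tnt_s config
B, N, num_pixel, inner_dim, dim = 2, 197, 16, 24, 384

pixel_embed = torch.randn(B * (N - 1), num_pixel, inner_dim)  # inner (pixel) tokens per patch
patch_embed = torch.randn(B, N, dim)                          # outer (patch) tokens incl. cls

# stand-ins for the Block's norm1_proj / proj / norm2_proj in the non-legacy path
norm1_proj = nn.LayerNorm(num_pixel * inner_dim)
proj = nn.Linear(num_pixel * inner_dim, dim)
norm2_proj = nn.LayerNorm(dim)

# non-legacy order: flatten the pixels of each patch, then norm -> linear proj -> norm,
# and add the result to every patch token except the class token
fused = norm2_proj(proj(norm1_proj(pixel_embed.reshape(B, N - 1, -1))))
patch_embed = torch.cat([patch_embed[:, 0:1], patch_embed[:, 1:] + fused], dim=1)
print(patch_embed.shape)  # torch.Size([2, 197, 384])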
@@ -152,7 +153,16 @@ def forward(self, pixel_embed, patch_embed):
 class PixelEmbed(nn.Module):
     """ Image to Pixel Embedding
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4, legacy=False):
+
+    def __init__(
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            in_dim=48,
+            stride=4,
+            legacy=False,
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -184,14 +194,17 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
 
     def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
-        _assert(H == self.img_size[0],
+        _assert(
+            H == self.img_size[0],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
-        _assert(W == self.img_size[1],
+        _assert(
+            W == self.img_size[1],
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
         if self.legacy:
             x = self.proj(x)
             x = self.unfold(x)
-            x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+            x = x.transpose(1, 2).reshape(
+                B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         else:
             x = self.unfold(x)
             x = x.transpose(1, 2).reshape(B * self.num_patches, C, self.patch_size[0], self.patch_size[1])
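As a side note on the non-legacy path above, a minimal sketch (not from the diff) of how nn.Unfold cuts the raw image into patches before any projection; the 224x224 input and patch_size=16 are illustrative and taken from the model defaults in this file, and the unfold kernel/stride choice is an assumption inferred from the reshape shown above.

import torch
import torch.nn as nn

B, C, H, W = 2, 3, 224, 224
patch_size = 16
num_patches = (H // patch_size) * (W // patch_size)  # 196

x = torch.randn(B, C, H, W)
unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)  # assumed unfold config
patches = unfold(x)                                            # (B, C*16*16, 196)
patches = patches.transpose(1, 2).reshape(B * num_patches, C, patch_size, patch_size)
print(patches.shape)                                           # torch.Size([392, 3, 16, 16])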
@@ -204,6 +217,7 @@ def forward(self, x: torch.Tensor, pixel_pos: torch.Tensor) -> torch.Tensor:
 class TNT(nn.Module):
     """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
     """
+
     def __init__(
             self,
             img_size=224,
@@ -248,7 +262,7 @@ def __init__(
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
         num_pixel = new_patch_size[0] * new_patch_size[1]
-
+
         self.norm1_proj = norm_layer(num_pixel * inner_dim)
         self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
         self.norm2_proj = norm_layer(embed_dim)
@@ -278,7 +292,7 @@ def __init__(
         self.blocks = nn.ModuleList(blocks)
         self.feature_info = [
             dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
-
+
         self.norm = norm_layer(embed_dim)
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -359,7 +373,7 @@ def forward_intermediates(
         B, _, height, width = x.shape
 
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
-
+
         patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
         patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
         patch_embed = patch_embed + self.patch_pos
@@ -381,7 +395,7 @@ def forward_intermediates(
         # split prefix (e.g. class, distill) and spatial feature tokens
         prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
         intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
-
+
         if reshape:
             # reshape to BCHW output format
             H, W = self.pixel_embed.dynamic_feat_size((height, width))
@@ -416,7 +430,7 @@ def prune_intermediate_layers(
     def forward_features(self, x):
         B = x.shape[0]
         pixel_embed = self.pixel_embed(x, self.pixel_pos)
-
+
         patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
         patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
         patch_embed = patch_embed + self.patch_pos
@@ -458,42 +472,47 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
+    'tnt_s_legacy_patch16_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        #url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+    ),
     'tnt_s_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        # url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_s_81.5.pth.tar',
     ),
     'tnt_b_patch16_224.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
+        hf_hub_id='timm/',
+        #url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/tnt/tnt_b_82.9.pth.tar',
     ),
 })
 
 
 def checkpoint_filter_fn(state_dict, model):
     state_dict.pop('outer_tokens', None)
-
-    out_dict = {}
-    for k, v in state_dict.items():
-        k = k.replace('outer_pos', 'patch_pos')
-        k = k.replace('inner_pos', 'pixel_pos')
-        k = k.replace('patch_embed', 'pixel_embed')
-        k = k.replace('proj_norm1', 'norm1_proj')
-        k = k.replace('proj_norm2', 'norm2_proj')
-        k = k.replace('inner_norm1', 'norm_in')
-        k = k.replace('inner_attn', 'attn_in')
-        k = k.replace('inner_norm2', 'norm_mlp_in')
-        k = k.replace('inner_mlp', 'mlp_in')
-        k = k.replace('outer_norm1', 'norm_out')
-        k = k.replace('outer_attn', 'attn_out')
-        k = k.replace('outer_norm2', 'norm_mlp')
-        k = k.replace('outer_mlp', 'mlp')
-        if k == 'pixel_pos' and model.pixel_embed.legacy == False:
-            B, N, C = v.shape
-            H = W = int(N ** 0.5)
-            assert H * W == N
-            v = v.permute(0, 2, 1).reshape(B, C, H, W)
-        out_dict[k] = v
+    if 'patch_pos' in state_dict:
+        out_dict = state_dict
+    else:
+        out_dict = {}
+        for k, v in state_dict.items():
+            k = k.replace('outer_pos', 'patch_pos')
+            k = k.replace('inner_pos', 'pixel_pos')
+            k = k.replace('patch_embed', 'pixel_embed')
+            k = k.replace('proj_norm1', 'norm1_proj')
+            k = k.replace('proj_norm2', 'norm2_proj')
+            k = k.replace('inner_norm1', 'norm_in')
+            k = k.replace('inner_attn', 'attn_in')
+            k = k.replace('inner_norm2', 'norm_mlp_in')
+            k = k.replace('inner_mlp', 'mlp_in')
+            k = k.replace('outer_norm1', 'norm_out')
+            k = k.replace('outer_attn', 'attn_out')
+            k = k.replace('outer_norm2', 'norm_mlp')
+            k = k.replace('outer_mlp', 'mlp')
+            if k == 'pixel_pos' and model.pixel_embed.legacy == False:
+                B, N, C = v.shape
+                H = W = int(N ** 0.5)
+                assert H * W == N
+                v = v.permute(0, 2, 1).reshape(B, C, H, W)
+            out_dict[k] = v
 
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if out_dict['patch_pos'].shape != model.patch_pos.shape:
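For clarity, a tiny standalone illustration (not part of the commit) of the pixel_pos conversion that checkpoint_filter_fn applies when remapping an original (non-timm) checkpoint onto a non-legacy model: the inner position embedding stored as (1, N, C) is rearranged into the (1, C, H, W) grid layout the current module expects. The (1, 16, 24) input shape below is what the tnt_s config (inner_dim=24, 4x4 pixel grid per patch) would imply; treat it as illustrative.

import torch

v = torch.randn(1, 16, 24)   # (B, N, C) as stored in an original TNT checkpoint (illustrative)
B, N, C = v.shape
H = W = int(N ** 0.5)
assert H * W == N
v = v.permute(0, 2, 1).reshape(B, C, H, W)
print(v.shape)               # torch.Size([1, 24, 4, 4])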
@@ -515,6 +534,15 @@ def _create_tnt(variant, pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def tnt_s_legacy_patch16_224(pretrained=False, **kwargs) -> TNT:
+    model_cfg = dict(
+        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
+        qkv_bias=False, legacy=True)
+    model = _create_tnt('tnt_s_legacy_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
+    return model
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
     model_cfg = dict(
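Putting the commit together from a user's perspective, a possible usage sketch (assuming the updated weights mentioned in the commit message are indeed published on the Hugging Face hub under the 'timm/' prefix the new cfgs point at):

import timm
import torch

# current (non-legacy) model; pretrained weights now resolve via hf_hub_id='timm/'
model = timm.create_model('tnt_s_patch16_224', pretrained=True).eval()

# the re-added legacy variant; created here without pretrained weights, since the diff
# adds a cfg for it but this sketch does not verify that the checkpoint is available
legacy = timm.create_model('tnt_s_legacy_patch16_224', pretrained=False).eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])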
