Skip to content

Commit 16d0b26

Browse files
committed
Fix torchscript issue with legacy tnt
1 parent 74ad32a commit 16d0b26

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

timm/models/tnt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __init__(
103103
if self.legacy:
104104
self.norm1_proj = norm_layer(dim)
105105
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)
106+
self.norm2_proj = None
106107
else:
107108
self.norm1_proj = norm_layer(dim * num_pixel)
108109
self.proj = nn.Linear(dim * num_pixel, dim_out, bias=False)
@@ -135,7 +136,7 @@ def forward(self, pixel_embed, patch_embed):
135136
pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
136137
# outer
137138
B, N, C = patch_embed.size()
138-
if self.legacy:
139+
if self.norm2_proj is None:
139140
patch_embed = torch.cat([
140141
patch_embed[:, 0:1],
141142
patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)),

0 commit comments

Comments
 (0)