
Commit 22f7c67

Merge branch 'alexander-soare-master'
2 parents: d4c00d6 + 30b9880

File tree: 2 files changed, +32 / −14 lines


timm/models/tnt.py

Lines changed: 23 additions & 8 deletions
@@ -14,7 +14,9 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import load_pretrained
 from timm.models.layers import Mlp, DropPath, trunc_normal_
+from timm.models.layers.helpers import to_2tuple
 from timm.models.registry import register_model
+from timm.models.vision_transformer import resize_pos_embed


 def _cfg(url='', **kwargs):
@@ -118,23 +120,27 @@ class PixelEmbed(nn.Module):
     """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
         super().__init__()
-        num_patches = (img_size // patch_size) ** 2
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        # grid_size property necessary for resizing positional embedding
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        num_patches = (self.grid_size[0]) * (self.grid_size[1])
         self.img_size = img_size
         self.num_patches = num_patches
         self.in_dim = in_dim
-        new_patch_size = math.ceil(patch_size / stride)
+        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
         self.new_patch_size = new_patch_size

         self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
         self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)

     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        assert H == self.img_size and W == self.img_size, \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         x = self.proj(x)
         x = self.unfold(x)
-        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size)
+        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         x = x + pixel_pos
         x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
         return x
@@ -155,15 +161,15 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
         num_patches = self.pixel_embed.num_patches
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
-        num_pixel = new_patch_size ** 2
+        num_pixel = new_patch_size[0] * new_patch_size[1]

         self.norm1_proj = norm_layer(num_pixel * in_dim)
         self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
         self.norm2_proj = norm_layer(embed_dim)

         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-        self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size))
+        self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
         self.pos_drop = nn.Dropout(p=drop_rate)

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
@@ -224,14 +230,23 @@ def forward(self, x):
         return x


+def checkpoint_filter_fn(state_dict, model):
+    """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    if state_dict['patch_pos'].shape != model.patch_pos.shape:
+        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
+            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+    return state_dict
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs):
     model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
         qkv_bias=False, **kwargs)
     model.default_cfg = default_cfgs['tnt_s_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
+            filter_fn=checkpoint_filter_fn)
     return model

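Taken together, the tnt.py changes let the pretrained TNT checkpoint be loaded into a model built for a different input resolution: load_pretrained now runs checkpoint_filter_fn, which resizes patch_pos via resize_pos_embed using the model's pixel_embed.grid_size. Below is a minimal sketch (not part of this commit) of how that path could be exercised; it assumes pretrained weights are published for tnt_s_patch16_224 and that img_size is forwarded through timm.create_model to the TNT constructor.

# Illustrative sketch only, not part of this commit.
import torch
import timm

# Building the model at a larger input size changes num_patches, so the checkpoint's
# patch_pos no longer matches the model and checkpoint_filter_fn resizes it.
model = timm.create_model('tnt_s_patch16_224', pretrained=True, img_size=448)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 448, 448))
print(logits.shape)  # expected: torch.Size([1, 1000])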
timm/models/vision_transformer.py

Lines changed: 9 additions & 6 deletions
@@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa
         nn.init.ones_(m.weight)


-def resize_pos_embed(posemb, posemb_new, num_tokens=1):
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
@@ -363,11 +363,13 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
     gs_old = int(math.sqrt(len(posemb_grid)))
-    gs_new = int(math.sqrt(ntok_new))
-    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
+    if not len(gs_new):  # backwards compatibility
+        gs_new = [int(math.sqrt(ntok_new))] * 2
+    assert len(gs_new) >= 2
+    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
+    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
     return posemb

@@ -385,7 +387,8 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1),
+                model.patch_embed.grid_size)
         out_dict[k] = v
     return out_dict

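For reference, a rough sketch (not from this commit) of how the extended resize_pos_embed signature behaves when a non-square target grid is passed explicitly; the tensor shapes below are illustrative only.

# Illustrative sketch: resizing a 14x14 position-embedding grid (plus one class token)
# to a non-square 12x20 grid by passing gs_new.
import torch
from timm.models.vision_transformer import resize_pos_embed

embed_dim = 8
posemb_old = torch.randn(1, 1 + 14 * 14, embed_dim)  # pretrained: cls token + 14x14 grid
posemb_new = torch.randn(1, 1 + 12 * 20, embed_dim)  # target model: cls token + 12x20 grid

resized = resize_pos_embed(posemb_old, posemb_new, num_tokens=1, gs_new=(12, 20))
print(resized.shape)  # expected: torch.Size([1, 241, 8])

# Omitting gs_new keeps the old behaviour: a square grid inferred as sqrt of the token count.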