Skip to content

Commit 8e4480e

Browse files
committed
Patch embed and position embed resampling is now always done in float32 (cast to float, then back to the original dtype). Fixes #1811
1 parent 150356c commit 8e4480e

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

timm/layers/patch_embed.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,11 @@ def resample_kernel(kernel):
197197
return resampled_kernel.reshape(new_size)
198198

199199
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
200-
return v_resample_kernel(patch_embed)
200+
orig_dtype = patch_embed.dtype
201+
patch_embed = patch_embed.float()
202+
patch_embed = v_resample_kernel(patch_embed)
203+
patch_embed = patch_embed.to(orig_dtype)
204+
return patch_embed
201205

202206

203207
# def divs(n, m=None):

timm/layers/pos_embed.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,12 @@ def resample_abs_pos_embed(
4040

4141
# do the interpolation
4242
embed_dim = posemb.shape[-1]
43+
orig_dtype = posemb.dtype
44+
posemb = posemb.float() # interpolate needs float32
4345
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
4446
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
4547
posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
48+
posemb = posemb.to(orig_dtype)
4649

4750
# add back extra (class, etc) prefix tokens
4851
if posemb_prefix is not None:
@@ -64,12 +67,12 @@ def resample_abs_pos_embed_nhwc(
6467
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
6568
return posemb
6669

67-
previous_dtype = posemb.dtype
70+
orig_dtype = posemb.dtype
6871
posemb = posemb.float()
6972
# do the interpolation
7073
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
7174
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
72-
posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
75+
posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype)
7376

7477
if not torch.jit.is_scripting() and verbose:
7578
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')

0 commit comments

Comments
 (0)