Skip to content

Commit 3b8ef3f

Browse files
authored
Merge pull request #1890 from Separius/patch-1
use float in resample_abs_pos_embed_nhwc
2 parents 8cb0dda + 40a518c commit 3b8ef3f

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

timm/layers/pos_embed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,14 @@ def resample_abs_pos_embed_nhwc(
6464
if new_size[0] == posemb.shape[-3] and new_size[1] == posemb.shape[-2]:
6565
return posemb
6666

67+
previous_dtype = posemb.dtype
68+
posemb = posemb.float()
6769
# do the interpolation
6870
posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2)
6971
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
70-
posemb = posemb.permute(0, 2, 3, 1)
72+
posemb = posemb.permute(0, 2, 3, 1).to(previous_dtype)
7173

7274
if not torch.jit.is_scripting() and verbose:
7375
_logger.info(f'Resized position embedding: {posemb.shape[-3:-1]} to {new_size}.')
7476

75-
return posemb
77+
return posemb

0 commit comments

Comments
 (0)