@@ -40,9 +40,12 @@ def resample_abs_pos_embed(
4040
4141 # do the interpolation
4242 embed_dim = posemb .shape [- 1 ]
43+ orig_dtype = posemb .dtype
44+ posemb = posemb .float () # interpolate needs float32
4345 posemb = posemb .reshape (1 , old_size [0 ], old_size [1 ], - 1 ).permute (0 , 3 , 1 , 2 )
4446 posemb = F .interpolate (posemb , size = new_size , mode = interpolation , antialias = antialias )
4547 posemb = posemb .permute (0 , 2 , 3 , 1 ).reshape (1 , - 1 , embed_dim )
48+ posemb = posemb .to (orig_dtype )
4649
4750 # add back extra (class, etc) prefix tokens
4851 if posemb_prefix is not None :
@@ -64,12 +67,12 @@ def resample_abs_pos_embed_nhwc(
6467 if new_size [0 ] == posemb .shape [- 3 ] and new_size [1 ] == posemb .shape [- 2 ]:
6568 return posemb
6669
67- previous_dtype = posemb .dtype
70+ orig_dtype = posemb .dtype
6871 posemb = posemb .float ()
6972 # do the interpolation
7073 posemb = posemb .reshape (1 , posemb .shape [- 3 ], posemb .shape [- 2 ], posemb .shape [- 1 ]).permute (0 , 3 , 1 , 2 )
7174 posemb = F .interpolate (posemb , size = new_size , mode = interpolation , antialias = antialias )
72- posemb = posemb .permute (0 , 2 , 3 , 1 ).to (previous_dtype )
75+ posemb = posemb .permute (0 , 2 , 3 , 1 ).to (orig_dtype )
7376
7477 if not torch .jit .is_scripting () and verbose :
7578 _logger .info (f'Resized position embedding: { posemb .shape [- 3 :- 1 ]} to { new_size } .' )
0 commit comments