import torch
import torch.nn as nn
import torch.nn.functional as F
+ from torch.nn.utils.rnn import pad_sequence

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import (
@@ -89,6 +90,7 @@ class NaFlexVitCfg:
    pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16)  # Grid size for position embedding initialization
    pos_embed_interp_mode: str = 'bicubic'  # Interpolation mode for position embedding resizing
    pos_embed_ar_preserving: bool = False  # Whether to preserve aspect ratio during position embedding interpolation
+     pos_embed_use_grid_sample: bool = False  # Whether to use grid_sample for naflex position embedding interpolation

    # Image processing
    dynamic_img_pad: bool = False  # Whether to enable dynamic padding for variable resolution
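As a quick illustration of the new config flag above, a minimal usage sketch (not part of the diff; the import path is assumed and all other NaFlexVitCfg fields are left at their defaults):

```python
# Hypothetical sketch: opt in to the grid_sample interpolation path via the config dataclass.
from timm.models.naflexvit import NaFlexVitCfg  # import path assumed

cfg = NaFlexVitCfg(
    pos_embed_use_grid_sample=True,   # new flag introduced in this diff
    pos_embed_ar_preserving=True,     # optional: preserve aspect ratio when resampling
)
```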
@@ -221,6 +223,7 @@ def __init__(
            pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
            pos_embed_interp_mode: str = 'bicubic',
            pos_embed_ar_preserving: bool = False,
+             pos_embed_use_grid_sample: bool = False,
            input_norm_layer: Optional[Type[nn.Module]] = None,
            proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None,
            norm_layer: Optional[Type[nn.Module]] = None,
@@ -256,6 +259,7 @@ def __init__(
        self.num_reg_tokens = reg_tokens
        self.pos_embed_interp_mode = pos_embed_interp_mode
        self.pos_embed_ar_preserving = pos_embed_ar_preserving
+         self.pos_embed_use_grid_sample = pos_embed_use_grid_sample
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
@@ -438,18 +442,6 @@ def _interp2d(size):
            )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
            return pos_embed_flat.to(dtype=x.dtype)

-         # FIXME leaving alternative code commented here for now for comparisons
-         # pos_embed_cache: Dict[Tuple[int, int], torch.Tensor] = {}
-         # for i, s in enumerate(naflex_grid_sizes):
-         #     if s in pos_embed_cache:
-         #         pos_embed_flat = pos_embed_cache[s]
-         #     else:
-         #         pos_embed_flat = _interp(s)
-         #         pos_embed_cache[s] = pos_embed_flat
-         #
-         #     seq_len = min(x.shape[1], pos_embed_flat.shape[1])
-         #     x[i, :seq_len] += pos_embed_flat[0, :seq_len]
-
        # Determine unique grid sizes to avoid duplicate interpolation
        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
        for bi, k in enumerate(naflex_grid_sizes):
@@ -467,6 +459,57 @@ def _interp2d(size):
                pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
            )

+     def _apply_learned_naflex_pos_embed_grid_sample(
+             self,
+             x: torch.Tensor,
+             naflex_grid_sizes: List[Tuple[int, int]],
+     ):
+         """NaFlex 2D position embedding interpolation using F.grid_sample.
+
+         Based on proposal by https://github.com/stas-sl
+         """
+         device = x.device
+         B, C = x.shape[0:2]
+
+         def _make_coords(h, w):
+             _y, _x = torch.meshgrid(
+                 torch.arange(h, device=device),
+                 torch.arange(w, device=device),
+                 indexing='ij',
+             )
+             coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
+             return coord
+
+         # (B, max_seq_len, 2) zero-padded (y, x) patch coordinates for each sample
+         coords = pad_sequence(
+             [_make_coords(h, w) for h, w in naflex_grid_sizes],
+             batch_first=True,
+         )
+         shapes = coords.amax(1) + 1  # per-sample grid (H, W)
+         # per-sample affine that stretches each sample's grid over the full pos_embed table
+         theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
+         if self.pos_embed_ar_preserving:
+             shape_max = shapes.amax()
+             grid_size = (shape_max, shape_max)
+             L = shapes.amax(1)
+             theta[:, 0, 0] = grid_size[1] / L  # scale x
+             theta[:, 1, 1] = grid_size[0] / L  # scale y
+         else:
+             grid_size = shapes.amax(0)
+             theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
+             theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
+         theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
+         theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
+         grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
+         pos_embed = F.grid_sample(
+             self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
+             grid,
+             mode=self.pos_embed_interp_mode,
+             align_corners=False,
+             padding_mode='border',
+         ).to(dtype=x.dtype)
+         bi = torch.arange(B, device=device).unsqueeze(1).expand(-1, coords.shape[1])
+         # NOTE leave as '+=', do not change to .add_(...)
+         x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]
+
    def _apply_learned_pos_embed(
            self,
            x: torch.Tensor,
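To make the affine construction in `_apply_learned_naflex_pos_embed_grid_sample` easier to follow, here is a self-contained sketch (not part of the diff; sizes and variable names are invented for the example) of the non-aspect-preserving branch: a single `F.affine_grid` + `F.grid_sample` call resamples one learned table to several per-sample grid sizes at once, each sample staying top-left aligned in a shared output canvas:

```python
import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 4, 4, 8)            # (1, H, W, C) learned table, H=W=4, C=8
grid_sizes = [(2, 4), (3, 3)]                  # per-sample (h, w) patch grids
B, C = len(grid_sizes), pos_embed.shape[-1]
shapes = torch.tensor(grid_sizes, dtype=torch.float32)
out_h, out_w = int(shapes[:, 0].max()), int(shapes[:, 1].max())  # shared output canvas (3, 4)

# Same affine construction as the non-AR-preserving branch: each sample's (h, w) sub-grid
# of the output canvas is stretched to cover the full pos_embed table.
theta = torch.zeros(B, 2, 3)
theta[:, 0, 0] = out_w / shapes[:, 1]          # scale x
theta[:, 1, 1] = out_h / shapes[:, 0]          # scale y
theta[:, 0, 2] = theta[:, 0, 0] - 1            # translate x (keep top-left aligned)
theta[:, 1, 2] = theta[:, 1, 1] - 1            # translate y

grid = F.affine_grid(theta, (B, C, out_h, out_w), align_corners=False)
resampled = F.grid_sample(
    pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1),  # (B, C, H, W)
    grid,
    mode='bicubic',
    align_corners=False,
    padding_mode='border',
)
print(resampled.shape)  # torch.Size([2, 8, 3, 4]); sample 0 uses rows 0-1, sample 1 uses cols 0-2
```

With `pos_embed_ar_preserving=True` the method instead scales both axes by the same per-sample factor (the global max dimension over each sample's max(h, w)), so aspect ratio is preserved; it then gathers only each sample's valid (y, x) positions via the padded `coords` index, whereas this sketch simply prints the full resampled canvas.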
@@ -516,7 +559,7 @@ def _apply_factorized_naflex_pos_embed(
        # Handle each batch element separately with its own grid size
        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]

-         # bucket samples that share the same (H,W) so we build each grid once
+         # bucket samples that share the same (H, W) so we build each grid once
        size_to_indices: Dict[Tuple[int, int], List[int]] = {}
        for bi, k in enumerate(naflex_grid_sizes):
            size_to_indices.setdefault(k, []).append(bi)
@@ -630,7 +673,10 @@ def forward(

        if self.pos_embed_type == 'learned':
            if naflex_grid_sizes is not None:
-                 self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
+                 if self.pos_embed_use_grid_sample:
+                     self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes)
+                 else:
+                     self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
            else:
                assert grid_size is not None
                self._apply_learned_pos_embed(x, grid_size=grid_size)
@@ -874,6 +920,7 @@ def __init__(
            pos_embed_grid_size=cfg.pos_embed_grid_size,
            pos_embed_interp_mode=cfg.pos_embed_interp_mode,
            pos_embed_ar_preserving=cfg.pos_embed_ar_preserving,
+             pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample,
            proj_norm_layer=embed_norm_layer,
            pos_drop_rate=cfg.pos_drop_rate,
            patch_drop_rate=cfg.patch_drop_rate,