@@ -462,55 +462,54 @@ def _interp2d(size):
     def _apply_learned_naflex_pos_embed_grid_sample(
             self,
             x: torch.Tensor,
-            naflex_grid_sizes: List[Tuple[int, int]],
+            patch_coord: torch.Tensor,
+            patch_valid: Optional[torch.Tensor] = None,
     ):
         """NaFlex 2D position embedding interpolation using F.grid_sample.

         Based on proposal by https://github.com/stas-sl
         """
         device = x.device
         B, N, C = x.shape
+        shapes = patch_coord.max(dim=1).values + 1  # (B, 2) containing [h_i, w_i]

-        def _make_coords(h, w):
-            _y, _x = torch.meshgrid(
-                torch.arange(h, device=device),
-                torch.arange(w, device=device),
-                indexing='ij',
-            )
-            coord = torch.stack([_y.flatten(), _x.flatten()], dim=1)
-            return coord
-
-        coords = torch.zeros(B, N, 2, dtype=torch.long, device=device)
-        for i, (h, w) in enumerate(naflex_grid_sizes):
-            coords_i = _make_coords(h, w)  # (h*w, 2)
-            coords[i, :coords_i.shape[0]] = coords_i  # pad with zeros past h*w
-            # FIXME should we be masking?
-
-        shapes = coords.amax(1) + 1
-        theta = torch.zeros(B, 2, 3, dtype=torch.float32, device=device)
         if self.pos_embed_ar_preserving:
-            L = shapes.amax(1)
-            grid_max = L.amax()
-            grid_size = (grid_max, grid_max)
-            theta[:, 0, 0] = grid_size[1] / L  # scale x
-            theta[:, 1, 1] = grid_size[0] / L  # scale y
+            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
+            L_global = L_i.amax()
+            grid_size = (L_global, L_global)
+            s_x = s_y = L_global / L_i  # uniform zoom (B,)
         else:
-            grid_size = shapes.amax(0)
-            theta[:, 0, 0] = grid_size[1] / shapes[:, 1]  # scale x
-            theta[:, 1, 1] = grid_size[0] / shapes[:, 0]  # scale y
+            grid_size = shapes.amax(dim=0)
+            s_x = grid_size[1] / shapes[:, 1]  # horizontal zoom (B,)
+            s_y = grid_size[0] / shapes[:, 0]  # vertical zoom (B,)
+
+        theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
+        theta[:, 0, 0] = s_x  # scale x
+        theta[:, 1, 1] = s_y  # scale y
         theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
         theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
+
         grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
         pos_embed = F.grid_sample(
             self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
             grid,
             mode=self.pos_embed_interp_mode,
             align_corners=False,
             padding_mode='border',
-        ).to(dtype=x.dtype)
+        ).to(dtype=x.dtype)  # (B, C, H_out, W_out)
+
+        # NOTE if we bring in patch_valid, can explicitly mask padding tokens
+        # more experimentation at train time needed
+        # lin_idx = patch_coord[..., 0] * grid_size[1] + patch_coord[..., 1]  # (B, N)
+        # pos_flat = pos_embed.flatten(2).transpose(1, 2)
+        # pos_flat = pos_flat.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C))  # (B, N, C)
+        # if patch_valid is not None:
+        #     pos_flat.mul_(patch_valid.unsqueeze(2))
+        # idx_vec = torch.arange(N, device=device)  # (N,)
+        # x.index_add_(1, idx_vec, pos_flat)
+
         bi = torch.arange(B, device=device).unsqueeze(1)
-        # NOTE leave as '+=', do not change to .add_(...)
-        x += pos_embed[bi, :, coords[..., 0], coords[..., 1]]
+        x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]  # NOTE leave as '+='

     def _apply_learned_pos_embed(
             self,
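
For context, the affine trick in the hunk above can be exercised in isolation: build a per-sample theta that zooms a shared learned grid to each sample's (h_i, w_i), resample with F.affine_grid + F.grid_sample, then gather per-token embeddings by their (y, x) patch coordinates. The following is a minimal sketch under assumed sizes (a 16x16 learned grid, two samples with 8x12 and 14x6 patch grids) and an assumed 'bilinear' mode; it is not the module's actual configuration.

import torch
import torch.nn.functional as F

B, C = 2, 64
pos_embed = torch.randn(1, 16, 16, C)           # shared learned grid (1, H, W, C), assumed size
shapes = torch.tensor([[8, 12], [14, 6]])       # per-sample (h_i, w_i) patch grids, assumed
grid_size = shapes.amax(dim=0)                  # common output size (H_out, W_out) = (14, 12)

theta = torch.zeros(B, 2, 3)
theta[:, 0, 0] = grid_size[1] / shapes[:, 1]    # scale x
theta[:, 1, 1] = grid_size[0] / shapes[:, 0]    # scale y
theta[:, 0, 2] = theta[:, 0, 0] - 1             # translate x (content top-left aligned)
theta[:, 1, 2] = theta[:, 1, 1] - 1             # translate y

grid = F.affine_grid(theta, (B, C, int(grid_size[0]), int(grid_size[1])), align_corners=False)
resampled = F.grid_sample(
    pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1),  # (B, C, H, W)
    grid, mode='bilinear', align_corners=False, padding_mode='border',
)                                               # (B, C, H_out, W_out)

# Gather per-token embeddings with (y, x) coords, e.g. for the first sample's 8x12 grid.
ys, xs = torch.meshgrid(torch.arange(8), torch.arange(12), indexing='ij')
tokens = resampled[0, :, ys.flatten(), xs.flatten()].T  # (8*12, C)

With align_corners=False and the translation term s - 1, output row/col (r, c) inside a sample's valid h_i x w_i region samples the full source grid proportionally, while out-of-range cells fall onto the border padding and are simply never indexed.
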
@@ -605,6 +604,7 @@ def forward(
             self,
             x: torch.Tensor,
             patch_coord: Optional[torch.Tensor] = None,
+            patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for patch embedding with position encoding.
@@ -676,7 +676,11 @@ def forward(
         if self.pos_embed_type == 'learned':
             if naflex_grid_sizes is not None:
                 if self.pos_embed_use_grid_sample:
-                    self._apply_learned_naflex_pos_embed_grid_sample(x, naflex_grid_sizes=naflex_grid_sizes)
+                    self._apply_learned_naflex_pos_embed_grid_sample(
+                        x,
+                        patch_coord=patch_coord,
+                        patch_valid=patch_valid,
+                    )
                 else:
                     self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
             else:
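
One consequence of the new signature is that per-sample grid sizes no longer need to be passed to the grid_sample path; they can be recovered from patch_coord itself. A tiny sketch with assumed toy coordinates:

import torch

patch_coord = torch.tensor([
    [[0, 0], [0, 1], [1, 0], [1, 1], [0, 0], [0, 0]],   # 2x2 grid plus 2 padding tokens
    [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]],   # 2x3 grid, no padding
])                                                       # (B, N, 2) as (y, x)
shapes = patch_coord.max(dim=1).values + 1               # tensor([[2, 2], [2, 3]])
# Padding coords are (0, 0), so they never affect the max; patch_valid is only needed
# if padded tokens must be masked after the pos-embed addition.
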
@@ -1146,7 +1150,7 @@ def forward_intermediates(
11461150 mask = create_attention_mask (patch_valid , self .num_prefix_tokens , patches .dtype )
11471151
11481152 # Forward pass through embedding
1149- x = self .embeds (patches , patch_coord = patch_coord )
1153+ x = self .embeds (patches , patch_coord = patch_coord , patch_valid = patch_valid )
11501154 x = self .norm_pre (x )
11511155
11521156 # Forward pass through blocks
@@ -1219,7 +1223,7 @@ def forward_features(
12191223 )
12201224
12211225 # Pass through embedding module with patch coordinate/type support
1222- x = self .embeds (x , patch_coord = patch_coord )
1226+ x = self .embeds (x , patch_coord = patch_coord , patch_valid = patch_valid )
12231227 x = self .norm_pre (x )
12241228 # Apply transformer blocks with masked attention if mask provided
12251229 if attn_mask is not None :
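
For callers, a rough sketch of how patch_coord and patch_valid might be assembled for a batch of two images with different patch grids; the helper name, sizes, and zero-filled patch tokens are illustrative, not the timm NaFlex collate API.

import torch

def make_naflex_batch(grid_sizes, embed_dim=64):
    # Hypothetical helper: pack variable-size patch grids into padded batch tensors.
    B = len(grid_sizes)
    n_max = max(h * w for h, w in grid_sizes)
    patches = torch.zeros(B, n_max, embed_dim)                # placeholder patch tokens, zero-padded
    patch_coord = torch.zeros(B, n_max, 2, dtype=torch.long)  # (y, x) per token, zeros past n_i
    patch_valid = torch.zeros(B, n_max, dtype=torch.bool)
    for i, (h, w) in enumerate(grid_sizes):
        yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
        n = h * w
        patch_coord[i, :n, 0] = yy.flatten()
        patch_coord[i, :n, 1] = xx.flatten()
        patch_valid[i, :n] = True
    return patches, patch_coord, patch_valid

patches, patch_coord, patch_valid = make_naflex_batch([(8, 12), (14, 6)])
# ...then, inside the model: x = self.embeds(patches, patch_coord=patch_coord, patch_valid=patch_valid)
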