@@ -165,6 +165,13 @@ def batch_patchify(
     return patches, (nh, nw)
 
 
+def calculate_naflex_grid_sizes(_coord: torch.Tensor):
+    # Calculate the appropriate grid size from coords
+    max_y = _coord[:, :, 0].amax(dim=1) + 1
+    max_x = _coord[:, :, 1].amax(dim=1) + 1
+    return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
+
+
 @register_notrace_module
 class NaFlexEmbeds(nn.Module):
     """NaFlex Embedding module for Vision Transformers.
@@ -407,18 +414,19 @@ def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
     def _apply_learned_naflex_pos_embed(
             self,
             x: torch.Tensor,
-            naflex_grid_sizes: List[Tuple[int, int]],
+            patch_coord: torch.Tensor,
     ) -> None:
         """Apply learned position embeddings to NaFlex batch in-place.
 
-        Interpolates learned position embeddings for each sample in the batch
+        Interpolates learned 2D position embeddings for each sample in the batch
         based on their individual grid sizes.
 
         Args:
-            x: Input tensor to add position embeddings to
-            naflex_grid_sizes: List of (height, width) grid sizes for each batch element
+            x: Input tensor to add position embeddings to [B, N, C]
+            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
         """
-        # Handle each batch element separately with its own grid size
+        # Calculate grid sizes from patch coordinates
+        naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
         orig_h, orig_w = self.pos_embed.shape[1:3]
         pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float()  # B,C,H,W
 
@@ -463,33 +471,37 @@ def _apply_learned_naflex_pos_embed_grid_sample(
             self,
             x: torch.Tensor,
             patch_coord: torch.Tensor,
-            patch_valid: Optional[torch.Tensor] = None,
-    ):
-        """ NaFlex 2D position embedding interpolation using F.grid_sample.
+    ) -> None:
+        """Apply learned position embeddings to NaFlex batch using grid_sample.
+
+        Uses F.grid_sample for efficient interpolation of learned 2D position embeddings
+        based on patch coordinates. Based on proposal by https://github.com/stas-sl
 
-        Based on proposal by https://github.com/stas-sl
+        Args:
+            x: Input tensor to add position embeddings to [B, N, C]
+            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
         """
         device = x.device
         B, N, C = x.shape
         shapes = patch_coord.max(dim=1).values + 1  # (B, 2) containing [h_i, w_i]
 
         if self.pos_embed_ar_preserving:
-            L_i = shapes.amax(dim=1)    # (B,) max(h_i, w_i)
+            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
             L_global = L_i.amax()
-            grid_size = (L_global, L_global)
-            s_x = s_y = L_global / L_i  # uniform zoom (B,)
+            grid_size_y = grid_size_x = L_global
+            scale_x = scale_y = L_global / L_i  # uniform zoom (B,)
         else:
-            grid_size = shapes.amax(dim=0)
-            s_x = grid_size[1] / shapes[:, 1]  # horizontal zoom (B,)
-            s_y = grid_size[0] / shapes[:, 0]  # vertical zoom (B,)
+            grid_size_y, grid_size_x = shapes.amax(dim=0)  # (2,)
+            scale_y = grid_size_y / shapes[:, 0]  # vertical zoom (B,)
+            scale_x = grid_size_x / shapes[:, 1]  # horizontal zoom (B,)
 
         theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
-        theta[:, 0, 0] = s_x  # scale x
-        theta[:, 1, 1] = s_y  # scale y
-        theta[:, 0, 2] = theta[:, 0, 0] - 1  # translate x
-        theta[:, 1, 2] = theta[:, 1, 1] - 1  # translate y
+        theta[:, 0, 0] = scale_x
+        theta[:, 1, 1] = scale_y
+        theta[:, 0, 2] = scale_x - 1  # translate x
+        theta[:, 1, 2] = scale_y - 1  # translate y
 
-        grid = F.affine_grid(theta, (B, C, *grid_size), align_corners=False)
+        grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False)
         pos_embed = F.grid_sample(
             self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
             grid,
@@ -498,16 +510,6 @@ def _apply_learned_naflex_pos_embed_grid_sample(
             padding_mode='border',
         ).to(dtype=x.dtype)  # (B, C, H_out, W_out)
 
-        # NOTE if we bring in patch_valid, can explicitly mask padding tokens
-        # more experimentation at train time needed
-        # lin_idx = patch_coord[..., 0] * grid_size[1] + patch_coord[..., 1]  # (B, N)
-        # pos_flat = pos_embed.flatten(2).transpose(1, 2)
-        # pos_flat = pos_flat.gather(1, lin_idx.unsqueeze(2).expand(-1, -1, C))  # (B, N, C)
-        # if patch_valid is not None:
-        #     pos_flat.mul_(patch_valid.unsqueeze(2))
-        # idx_vec = torch.arange(N, device=device)  # (N,)
-        # x.index_add_(1, idx_vec, pos_flat)
-
         bi = torch.arange(B, device=device).unsqueeze(1)
         x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]  # NOTE leave as '+='
 
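A standalone sketch of the affine math used above (my own check, not from the commit): with theta rows [s, 0, s-1] and [0, s, s-1], affine_grid maps output coordinate x in [-1, 1] to input coordinate s*x + (s - 1), so the top-left edge stays anchored (x = -1 maps to -1) and the whole pos_embed table lands in the first 1/s fraction of the output grid, i.e. the first w_i of grid_size_x columns when s = grid_size_x / w_i; border padding fills the rest.

import torch
import torch.nn.functional as F

table = torch.arange(4.).view(1, 1, 1, 4)  # 1D table of length 4: [0, 1, 2, 3]
s = 2.0                                    # this sample's grid is half the batch max
theta = torch.tensor([[[s, 0., s - 1.],    # x: scale s, translate s - 1
                       [0., 1., 0.]]])     # y: identity (height 1 here)
grid = F.affine_grid(theta, (1, 1, 1, 8), align_corners=False)
out = F.grid_sample(table, grid, mode='bilinear', align_corners=False, padding_mode='border')
print(out.flatten())  # tensor([0., 1., 2., 3., 3., 3., 3., 3.])
# The full table occupies the first 4 of 8 output columns; the remainder is
# border-padded, matching how each sample's h_i x w_i region is laid out
# before the per-patch coordinate gather.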
@@ -516,45 +518,48 @@ def _apply_learned_pos_embed(
             x: torch.Tensor,
             grid_size: List[int],
     ) -> None:
-        """Apply learned position embeddings to standard batch in-place.
+        """Apply learned position embeddings to standard 2D batch in-place.
 
-        Interpolates learned position embeddings to match the specified grid size.
+        Interpolates learned 2D position embeddings to match the specified grid size.
 
         Args:
-            x: Input tensor to add position embeddings to
+            x: Input tensor to add position embeddings to [B, H*W, C]
             grid_size: Target grid size as [height, width]
         """
         orig_h, orig_w = self.pos_embed.shape[1:3]
-        if grid_size[0] == orig_h or grid_size[1] == orig_w:
+        if grid_size[0] == orig_h and grid_size[1] == orig_w:
             # No resize needed, just flatten
             pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
         else:
             # Resize if needed - directly using F.interpolate
+            _interp_size = to_2tuple(max(grid_size)) if self.pos_embed_ar_preserving else grid_size
             pos_embed_flat = F.interpolate(
                 self.pos_embed.permute(0, 3, 1, 2).float(),  # B,C,H,W
-                size=grid_size,
+                size=_interp_size,
                 mode=self.pos_embed_interp_mode,
                 align_corners=False,
                 antialias=True,
-            ).flatten(2).transpose(1, 2)
+            )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2)
             pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
 
         x.add_(pos_embed_flat)
 
     def _apply_factorized_naflex_pos_embed(
             self,
             x: torch.Tensor,
-            naflex_grid_sizes: List[Tuple[int, int]],
+            patch_coord: torch.Tensor,
     ) -> None:
         """Apply factorized position embeddings to NaFlex batch in-place.
 
         Uses separate Y and X position embedding tables that are interpolated
         and combined for each sample's grid size.
 
         Args:
-            x: Input tensor to add position embeddings to
-            naflex_grid_sizes: List of (height, width) grid sizes for each batch element
+            x: Input tensor to add position embeddings to [B, N, C]
+            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
         """
+        # Calculate grid sizes from patch coordinates
+        naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
         assert len(naflex_grid_sizes) == x.size(0)  # one (H,W) per sample
 
         # Handle each batch element separately with its own grid size
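A minimal sketch of the aspect-preserving branch added to _apply_learned_pos_embed (standalone; 'bicubic' stands in for self.pos_embed_interp_mode): the table is resized to a square max(grid) x max(grid), so both axes share one scale factor, then cropped to the actual target grid.

import torch
import torch.nn.functional as F

pos_embed = torch.randn(1, 16, 16, 64)      # learned table, (1, H, W, C)
grid_size = (12, 8)                         # target (h, w), non-square
_interp_size = (max(grid_size),) * 2        # (12, 12): one uniform zoom 16 -> 12
pos_embed_flat = F.interpolate(
    pos_embed.permute(0, 3, 1, 2).float(),  # (1, C, H, W)
    size=_interp_size,
    mode='bicubic',
    align_corners=False,
    antialias=True,
)[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2)
print(pos_embed_flat.shape)                 # torch.Size([1, 96, 64]) == (1, h*w, C)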
@@ -600,11 +605,99 @@ def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
             pos[:, :seq_len].expand(len(batch_indices), -1, -1)
         )
 
+    def _apply_factorized_naflex_pos_embed_grid_sample(
+            self,
+            x: torch.Tensor,
+            patch_coord: torch.Tensor,
+    ) -> None:
+        """Apply factorized position embeddings to NaFlex batch using grid_sample.
+
+        Uses F.grid_sample for efficient interpolation of separate Y and X position
+        embedding tables based on patch coordinates. Based on proposal by https://github.com/stas-sl
+
+        Args:
+            x: Input tensor to add position embeddings to [B, N, C]
+            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
+        """
+        device = x.device
+        B, _, C = x.shape
+        shapes = patch_coord.amax(dim=1) + 1
+
+        if self.pos_embed_ar_preserving:
+            # Aspect ratio preserving mode: use square grid with uniform scaling
+            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
+            L_global = L_i.amax()
+            grid_size_y = grid_size_x = L_global
+            scale_x = scale_y = L_global / L_i  # uniform zoom (B,)
+        else:
+            # Standard mode: different scaling for x and y
+            grid_size_y, grid_size_x = shapes.amax(0)
+            scale_x = grid_size_x / shapes[:, 1]  # horizontal zoom (B,)
+            scale_y = grid_size_y / shapes[:, 0]  # vertical zoom (B,)
+
+        def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor:
+            pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float()  # (1, L, C) -> (B, C, 1, L)
+            theta = torch.zeros(B, 2, 3, device=x.device)
+            theta[:, 0, 0] = scale
+            theta[:, 0, 2] = scale - 1
+            theta[:, 1, 1] = 1
+            grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False)
+            pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border')
+            return pe.to(x.dtype)
+
+        # Interpolate along each axis
+        pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x)
+        pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y)
+
+        bi = torch.arange(B, device=device).unsqueeze(1)
+        x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]]
+
+    def _apply_factorized_pos_embed(
+            self,
+            x: torch.Tensor,
+            grid_size: List[int],
+    ) -> None:
+        """Apply factorized position embeddings to standard 2D batch in-place.
+
+        Uses separate Y and X position embedding tables that are interpolated
+        and combined for the specified grid size.
+
+        Args:
+            x: Input tensor to add position embeddings to [B, H*W, C]
+            grid_size: Target grid size as [height, width]
+        """
+        orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
+        target_h, target_w = grid_size
+
+        if self.pos_embed_ar_preserving:
+            len_y = len_x = max(target_h, target_w)
+        else:
+            len_y, len_x = target_h, target_w
+
+        def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
+            if new_length == orig_length:
+                return table.to(dtype=x.dtype)
+            return F.interpolate(
+                table.permute(0, 2, 1).float(),  # (1,L,C) -> (1,C,L)
+                size=new_length,
+                mode='linear',
+                align_corners=False,
+            ).permute(0, 2, 1).to(dtype=x.dtype)  # (1,L,C)
+
+        # Interpolate embeddings
+        pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h]  # (1,H,C)
+        pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w]  # (1,W,C)
+
+        # Broadcast, add and flatten to sequence layout (row major)
+        pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)  # (1, H, W, C)
+        pos_embed_flat = pos_embed.flatten(1, 2)  # (1, H*W, C)
+
+        x.add_(pos_embed_flat)
+
     def forward(
             self,
             x: torch.Tensor,
             patch_coord: Optional[torch.Tensor] = None,
-            patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass for patch embedding with position encoding.
 
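Standalone check of how the factorized tables combine (not in the diff): a (1, H, C) y-table and a (1, W, C) x-table broadcast-add into a row-major (1, H*W, C) sequence, so token n = y * W + x receives pe_y[y] + pe_x[x], the same decomposition the grid_sample variant reproduces with two 1D gathers.

import torch

H, W, C = 3, 4, 8
pe_y = torch.randn(1, H, C)
pe_x = torch.randn(1, W, C)
pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1)  # (1, H, W, C)
pos_embed_flat = pos_embed.flatten(1, 2)           # (1, H*W, C), row major
y, x = 2, 1
assert torch.allclose(pos_embed_flat[0, y * W + x], pe_y[0, y] + pe_x[0, x])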
@@ -619,24 +712,18 @@ def forward(
             Embedded tensor with position encoding and class/register tokens.
             Shape: [B, num_prefix_tokens + N, embed_dim]
         """
-        # Apply patch embedding
-        naflex_grid_sizes: Optional[List[Tuple[int, int]]] = None
         grid_size: Optional[List[int]] = None
-
         B = x.shape[0]
         if self.is_linear:
             # Linear embedding path, works with NaFlex mode or standard 2D mode
-            if patch_coord is not None:
+            if patch_coord is None:
+                # Standard 2D (B, C, H, W) mode
+                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
+                x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
+            else:
                 # Pre-patchified NaFlex mode
                 # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C]
                 _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.')
-                # Calculate the appropriate grid size from coords
-                max_y = patch_coord[:, :, 0].max(dim=1)[0] + 1
-                max_x = patch_coord[:, :, 1].max(dim=1)[0] + 1
-                naflex_grid_sizes = [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
-            else:
-                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
-                x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
 
             # Handle variable patch size projection
             if self.enable_patch_interpolator and x.ndim == 5:
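A shape sketch of the standard 2D path (a hypothetical stand-in for batch_patchify, which per the hunk above returns (patches, (nh, nw)); the real flatten order may differ): a (B, C, H, W) image unfolds into (B, nh*nw, P*P*C) flattened patches ready for the linear projection.

import torch

def batch_patchify_sketch(x: torch.Tensor, patch_size: int):
    B, C, H, W = x.shape
    nh, nw = H // patch_size, W // patch_size
    patches = (
        x.unfold(2, patch_size, patch_size)   # (B, C, nh, W, P)
         .unfold(3, patch_size, patch_size)   # (B, C, nh, nw, P, P)
         .permute(0, 2, 3, 4, 5, 1)           # (B, nh, nw, P, P, C)
         .reshape(B, nh * nw, -1)             # (B, N, P*P*C)
    )
    return patches, (nh, nw)

x = torch.randn(2, 3, 32, 48)
patches, grid_size = batch_patchify_sketch(x, 16)
print(patches.shape, grid_size)  # torch.Size([2, 6, 768]) (2, 3)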
@@ -674,21 +761,25 @@ def forward(
         x = self.norm(x)
 
         if self.pos_embed_type == 'learned':
-            if naflex_grid_sizes is not None:
+            if grid_size is not None:
+                # Standard 2D mode
+                self._apply_learned_pos_embed(x, grid_size=grid_size)
+            else:
+                # NaFlex mode
                 if self.pos_embed_use_grid_sample:
-                    self._apply_learned_naflex_pos_embed_grid_sample(
-                        x,
-                        patch_coord=patch_coord,
-                        patch_valid=patch_valid,
-                    )
+                    self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
                 else:
-                    self._apply_learned_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
-            else:
-                assert grid_size is not None
-                self._apply_learned_pos_embed(x, grid_size=grid_size)
+                    self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord)
         elif self.pos_embed_type == 'factorized':
-            if naflex_grid_sizes is not None:
-                self._apply_factorized_naflex_pos_embed(x, naflex_grid_sizes=naflex_grid_sizes)
+            if grid_size is not None:
+                # Standard 2D mode
+                self._apply_factorized_pos_embed(x, grid_size=grid_size)
+            else:
+                # NaFlex mode
+                if self.pos_embed_use_grid_sample:
+                    self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
+                else:
+                    self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord)
         elif self.pos_embed_type == 'rope':
             assert False, "ROPE not yet implemented"
 
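Both grid_sample branches dispatched above end with the same advanced-indexing gather (x += pos_embed[bi, :, y, x]); a standalone check of its semantics, since the expression is easy to misread: with bi of shape (B, 1) broadcasting against (B, N) coordinates, the advanced dimensions move to the front and the channel slice follows, yielding (B, N, C).

import torch

B, C, H, W, N = 2, 5, 3, 4, 6
pos_embed = torch.randn(B, C, H, W)
patch_coord = torch.stack([
    torch.randint(0, H, (B, N)),  # y
    torch.randint(0, W, (B, N)),  # x
], dim=-1)
bi = torch.arange(B).unsqueeze(1)  # (B, 1)
gathered = pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]
print(gathered.shape)  # torch.Size([2, 6, 5]) == (B, N, C)
assert torch.equal(
    gathered[1, 3],
    pos_embed[1, :, patch_coord[1, 3, 0], patch_coord[1, 3, 1]],
)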
@@ -1150,7 +1241,7 @@ def forward_intermediates(
             mask = create_attention_mask(patch_valid, self.num_prefix_tokens, patches.dtype)
 
         # Forward pass through embedding
-        x = self.embeds(patches, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(patches, patch_coord=patch_coord)
         x = self.norm_pre(x)
 
         # Forward pass through blocks
@@ -1223,7 +1314,7 @@ def forward_features(
         )
 
         # Pass through embedding module with patch coordinate/type support
-        x = self.embeds(x, patch_coord=patch_coord, patch_valid=patch_valid)
+        x = self.embeds(x, patch_coord=patch_coord)
         x = self.norm_pre(x)
         # Apply transformer blocks with masked attention if mask provided
         if attn_mask is not None: