@@ -234,10 +234,40 @@ def apply_rot_embed_cat(x: torch.Tensor, emb):
     return x * cos_emb + rot(x) * sin_emb
 
 
-def apply_keep_indices_nlc(x, pos_embed, keep_indices):
-    pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
-    pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
-    return pos_embed
+def apply_keep_indices_nlc(
+        x: torch.Tensor,
+        pos_embed: torch.Tensor,
+        keep_indices: torch.Tensor,
+        pos_embed_has_batch: bool = False,
+) -> torch.Tensor:
+    """Apply keep indices to different ROPE shapes.
+
+    Expected shapes:
+        * pos_embed shape [seq_len, pos_embed_dim] → output [batch_size, seq_len, pos_embed_dim]
+        * pos_embed shape [num_heads, seq_len, pos_embed_dim] → output [batch_size, num_heads, seq_len, pos_embed_dim]
+        * pos_embed shape [depth, num_heads, seq_len, pos_embed_dim] → output [batch_size, depth, num_heads, seq_len, pos_embed_dim]
+
+    Any of the above with a leading batch dimension already present is also supported if `pos_embed_has_batch == True`.
+    """
+    if pos_embed_has_batch:
+        # Pos embed already includes batch dim
+        _assert(pos_embed.ndim >= 3, 'Incorrect number of dimensions')  # At least [batch, seq_len, pos_embed_dim]
+    else:
+        # Add batch dimension and expand to batch size
+        _assert(pos_embed.ndim >= 2, 'Incorrect number of dimensions')  # At least [seq_len, pos_embed_dim]
+        expand_shape = (x.shape[0],) + (-1,) * pos_embed.ndim
+        pos_embed = pos_embed.unsqueeze(0).expand(expand_shape)
+
+    # Reshape keep_indices to add singleton dims
+    keep_shape = (keep_indices.shape[0],) + (1,) * (pos_embed.ndim - 3) + (keep_indices.shape[1], 1)
+    keep_indices = keep_indices.view(keep_shape)
+
+    # Expand all dims to match position embedding except the gather dim (second-last)
+    keep_expand = list(pos_embed.shape)
+    keep_expand[-2] = -1
+    keep_indices = keep_indices.expand(keep_expand)
+
+    return pos_embed.gather(-2, keep_indices)
 
 
 def build_rotary_pos_embed(
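A minimal usage sketch for the reworked apply_keep_indices_nlc, covering the 2D and per-head pos_embed layouts described in the docstring. It assumes the function is importable from timm.layers.pos_embed_sincos; the batch size, sequence length, head count, and dim below are illustrative assumptions, not values from this commit:

    import torch
    from timm.layers.pos_embed_sincos import apply_keep_indices_nlc

    B, L, keep, dim = 2, 196, 128, 64   # assumed sizes, for illustration only
    x = torch.randn(B, L, dim)
    # Per-sample indices of the tokens kept after dropping, shape (B, keep)
    keep_indices = torch.stack([torch.randperm(L)[:keep] for _ in range(B)])

    # pos_embed [seq_len, dim] -> gathered [B, keep, dim]
    out = apply_keep_indices_nlc(x, torch.randn(L, dim), keep_indices)

    # pos_embed [num_heads, seq_len, dim] -> gathered [B, num_heads, keep, dim]
    out = apply_keep_indices_nlc(x, torch.randn(12, L, dim), keep_indices)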
@@ -484,6 +514,59 @@ def get_embed(self, shape: Optional[List[int]] = None):
         else:
             assert False, "get_embed() requires pre-computed pos embed or valid shape w/ pre-computed bands"
 
+    def get_batch_embeds(
+            self,
+            shapes: List[Tuple[int, int]],
+            seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Generate ROPE embeddings for multiple grid shapes efficiently.
+
+        Computes embeddings for the maximum grid size once, then extracts
+        and flattens the relevant portions for each requested shape.
+
+        Args:
+            shapes: List of (H, W) tuples representing different grid sizes
+            seq_len: If provided, return padded tensor of this length. Otherwise return list.
+
+        Returns:
+            If seq_len is provided: Padded tensor of shape (len(shapes), seq_len, dim)
+            Otherwise: List of concatenated sin/cos embeddings for each shape, where each tensor has shape (H*W, dim)
+        """
+        if not shapes:
+            return []
+
+        # Check if we have pre-computed bands
+        if self.bands is None:
+            # If we have pre-computed pos_embed for a fixed shape, we can't do batch generation
+            raise RuntimeError("Batch embedding generation requires cached bands, not pre-computed embeddings")
+
+        # Find max dimensions across all shapes
+        max_h = max(h for h, w in shapes)
+        max_w = max(w for h, w in shapes)
+
+        # Generate embeddings for max size ONCE
+        sin_emb, cos_emb = build_rotary_pos_embed(
+            feat_shape=(max_h, max_w),
+            bands=self.bands,
+            in_pixels=self.in_pixels,
+            ref_feat_shape=self.ref_feat_shape,
+            grid_offset=self.grid_offset,
+            grid_indexing=self.grid_indexing,
+        )
+
+        # sin_emb and cos_emb are (max_h * max_w, dim // 2)
+        # concat and reshape to 2D for slicing
+        rope_embed_2d = torch.cat([sin_emb, cos_emb], dim=-1).view(max_h, max_w, -1)
+
+        if seq_len is not None:
+            flat_embeds = torch.zeros(len(shapes), seq_len, rope_embed_2d.shape[-1]).type_as(sin_emb)
+            for i, (h, w) in enumerate(shapes):
+                src_len = h * w
+                flat_embeds[i, :src_len] = rope_embed_2d[:h, :w].reshape(src_len, -1)
+            return flat_embeds
+        else:
+            flat_embeds_list = [rope_embed_2d[:h, :w].reshape(h * w, -1) for h, w in shapes]
+            return flat_embeds_list
+
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
         pos_embed = self.get_embed(x.shape[2:])
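A sketch of how the two return modes might be used, assuming this get_batch_embeds() lands on timm's RotaryEmbeddingCat and that the constructor arguments shown are valid for the version in this commit:

    import torch
    from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

    rope = RotaryEmbeddingCat(64, in_pixels=False, ref_feat_shape=(16, 16))  # assumed config
    shapes = [(16, 16), (12, 20), (8, 8)]

    per_shape = rope.get_batch_embeds(shapes)                 # list of (H*W, dim) tensors
    padded = rope.get_batch_embeds(shapes, seq_len=16 * 16)   # (3, 256, dim), zero-filled past each H*W

Note that seq_len must be at least max(H*W) over shapes, otherwise the copy into the padded tensor will fail.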
@@ -642,6 +725,62 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
 
         return get_mixed_freqs(self.freqs, t_x, t_y)
 
+    def get_batch_embeds(
+            self,
+            shapes: List[Tuple[int, int]],
+            seq_len: Optional[int] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Generate ROPE embeddings for multiple grid shapes efficiently.
+
+        Computes embeddings for the maximum grid size once, then extracts
+        and flattens the relevant portions for each requested shape.
+
+        Args:
+            shapes: List of (H, W) tuples representing different grid sizes
+            seq_len: If provided, return padded tensor of this length. Otherwise return list.
+
+        Returns:
+            If seq_len is provided: Padded tensor of shape (len(shapes), depth, num_heads, seq_len, dim)
+            Otherwise: List of tensors with shape (depth, num_heads, H*W, dim) for each shape
+        """
+        if not shapes:
+            return []
+
+        # Find max dimensions
+        max_h = max(h for h, w in shapes)
+        max_w = max(w for h, w in shapes)
+
+        # Generate embeddings for max size ONCE
+        t_x, t_y = get_mixed_grid(
+            [max_h, max_w],
+            grid_indexing=self.grid_indexing,
+            device=self.freqs.device,
+        )
+        max_embed = get_mixed_freqs(self.freqs, t_x, t_y)  # (depth, num_heads, max_h * max_w, dim)
+
+        # Reshape to 2D grid for easy slicing
+        depth, num_heads, _, dim = max_embed.shape
+        max_embed_2d = max_embed.view(depth, num_heads, max_h, max_w, dim)
+
+        if seq_len is not None:
+            # Return padded tensor
+            B = len(shapes)
+            padded = torch.zeros(B, depth, num_heads, seq_len, dim, device=self.freqs.device, dtype=self.freqs.dtype)
+            for i, (h, w) in enumerate(shapes):
+                # Slice and flatten
+                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
+                actual_len = h * w
+                padded[i, :, :, :actual_len] = embed_slice
+            return padded
+        else:
+            # Return list
+            results = []
+            for h, w in shapes:
+                # Slice and flatten
+                embed_slice = max_embed_2d[:, :, :h, :w].reshape(depth, num_heads, h * w, dim)
+                results.append(embed_slice)
+            return results
+
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
         pos_embed = self.get_embed(x.shape[2:])
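The mixed-mode variant follows the same pattern but keeps per-depth, per-head frequencies. A sketch assuming the method lands on timm's RotaryEmbeddingMixed (the constructor arguments shown are assumptions) and that the padded embeds are later subset with apply_keep_indices_nlc:

    import torch
    from timm.layers.pos_embed_sincos import RotaryEmbeddingMixed, apply_keep_indices_nlc

    rope = RotaryEmbeddingMixed(dim=64, depth=12, num_heads=12)   # assumed args
    shapes = [(14, 14), (10, 18)]
    seq_len = max(h * w for h, w in shapes)

    padded = rope.get_batch_embeds(shapes, seq_len=seq_len)       # (2, 12, 12, seq_len, 64)

    # Because padded already carries a batch dim, token dropping would use the new flag, e.g.:
    # apply_keep_indices_nlc(x, padded, keep_indices, pos_embed_has_batch=True)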