@@ -436,6 +436,166 @@ def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
         return resampled_patch_embed


+class PatchEmbedInterpolator(nn.Module):
+    """Dynamically interpolates patch embedding weights for variable patch sizes.
+
+    This module wraps patch embedding weight resampling functionality to support
+    on-the-fly patch size variation during training. It handles both Conv2d and
+    Linear patch embeddings.
+
+    Args:
+        base_patch_size: The original patch size the model was initialized with
+        in_chans: Number of input channels
+        embed_dim: Embedding dimension
+        interpolation: Interpolation mode for resampling
+        antialias: Whether to use antialiasing during interpolation
+    """
+
+    def __init__(
+            self,
+            base_patch_size: Tuple[int, int],
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            interpolation: str = 'bicubic',
+            antialias: bool = True,
+    ):
+        super().__init__()
+        self.base_patch_size = base_patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.interpolation = interpolation
+        self.antialias = antialias
+
+    def resample_linear_weight(
+            self,
+            weight: torch.Tensor,
+            target_patch_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Resample linear patch embedding weights for a new patch size.
+
+        Args:
+            weight: Linear weight tensor of shape [embed_dim, patch_h * patch_w * in_chans]
+            target_patch_size: Target (patch_h, patch_w) to resample to
+
+        Returns:
+            Resampled weight tensor
+        """
+        if target_patch_size == self.base_patch_size:
+            return weight
+
+        embed_dim = weight.shape[0]
+        base_ph, base_pw = self.base_patch_size
+        target_ph, target_pw = target_patch_size
+
+        # Reshape linear weight to conv2d format
+        # [embed_dim, ph*pw*C] -> [embed_dim, C, ph, pw]
+        weight_conv = weight.reshape(embed_dim, base_ph, base_pw, self.in_chans)
+        weight_conv = weight_conv.permute(0, 3, 1, 2)
+
+        # Resample using existing function
+        weight_conv_resampled = resample_patch_embed(
+            weight_conv,
+            new_size=[target_ph, target_pw],
+            interpolation=self.interpolation,
+            antialias=self.antialias,
+            verbose=False,
+        )
+
+        # Reshape back to linear format
+        # [embed_dim, C, ph, pw] -> [embed_dim, ph*pw*C]
+        weight_resampled = weight_conv_resampled.permute(0, 2, 3, 1)
+        weight_resampled = weight_resampled.reshape(embed_dim, -1)
+
+        return weight_resampled
+
+    def resample_conv_weight(
+            self,
+            weight: torch.Tensor,
+            target_patch_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Resample conv2d patch embedding weights for a new patch size.
+
+        Args:
+            weight: Conv2d weight tensor of shape [embed_dim, in_chans, patch_h, patch_w]
+            target_patch_size: Target (patch_h, patch_w) to resample to
+
+        Returns:
+            Resampled weight tensor
+        """
+        if target_patch_size == self.base_patch_size:
+            return weight
+
+        # Resample using existing function
+        weight_resampled = resample_patch_embed(
+            weight,
+            new_size=list(target_patch_size),
+            interpolation=self.interpolation,
+            antialias=self.antialias,
+            verbose=False,
+        )
+
+        return weight_resampled
+
+    def forward(
+            self,
+            patches: torch.Tensor,
+            proj_weight: torch.Tensor,
+            proj_bias: Optional[torch.Tensor] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+            is_linear: bool = True,
+    ) -> torch.Tensor:
+        """Apply patch embedding with dynamic weight resampling.
+
+        Args:
+            patches: Input patches
+                - For linear mode with resampling: [B, N, Ph, Pw, C]
+                - For linear mode without resampling: [B, N, Ph*Pw*C]
+                - For conv mode: [B, C, H, W]
+            proj_weight: Original projection weight
+            proj_bias: Optional projection bias
+            patch_size: Current patch size (if None, uses base_patch_size)
+            is_linear: Whether using linear (True) or conv2d (False) projection
+
+        Returns:
+            Embedded patches
+        """
+        if patch_size is None:
+            patch_size = self.base_patch_size
+
+        if is_linear:
+            if patch_size != self.base_patch_size:
+                # Need to resample - expects unflattened patches
+                assert patches.ndim == 5, "Patches must be [B, N, Ph, Pw, C] for resampling"
+                B, N, Ph, Pw, C = patches.shape
+
+                # Resample the weight
+                weight_resampled = self.resample_linear_weight(proj_weight, patch_size)
+
+                # Flatten patches and apply linear projection
+                patches_flat = patches.reshape(B, N, -1)
+                output = torch.nn.functional.linear(patches_flat, weight_resampled, proj_bias)
+            else:
+                # No resampling needed, patches can be pre-flattened
+                if patches.ndim == 5:
+                    B, N, Ph, Pw, C = patches.shape
+                    patches = patches.reshape(B, N, -1)
+                output = torch.nn.functional.linear(patches, proj_weight, proj_bias)
+        else:
+            # Conv mode
+            if patch_size != self.base_patch_size:
+                weight_resampled = self.resample_conv_weight(proj_weight, patch_size)
+                output = torch.nn.functional.conv2d(
+                    patches, weight_resampled, proj_bias,
+                    stride=patch_size, padding=0,
+                )
+            else:
+                output = torch.nn.functional.conv2d(
+                    patches, proj_weight, proj_bias,
+                    stride=patch_size, padding=0,
+                )
+
+        return output
+
 # def divs(n, m=None):
 #     m = m or n // 2
 #     if m == 1:
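A minimal usage sketch for the new PatchEmbedInterpolator (not part of this diff): the tensor shapes follow the docstrings above, while the weight/variable names and the 8x8 target patch size are illustrative assumptions. It assumes the class and resample_patch_embed from this module are importable in the current scope.

# Illustrative only: exercises both the conv2d and linear projection paths
# with a patch size that differs from the base size.
import torch

interp = PatchEmbedInterpolator(base_patch_size=(16, 16), in_chans=3, embed_dim=768)

# Conv mode: a Conv2d projection weight [embed_dim, in_chans, 16, 16] applied
# to images [B, C, H, W] with an on-the-fly 8x8 patch size.
conv_weight = torch.randn(768, 3, 16, 16)
conv_bias = torch.zeros(768)
images = torch.randn(2, 3, 64, 64)
out_conv = interp(images, conv_weight, conv_bias, patch_size=(8, 8), is_linear=False)
# out_conv: [2, 768, 8, 8] -- the kernel is resampled to 8x8 and used with stride 8

# Linear mode: unflattened patches [B, N, Ph, Pw, C] at the target patch size,
# with the linear weight laid out as [embed_dim, 16*16*3].
linear_weight = torch.randn(768, 16 * 16 * 3)
patches = torch.randn(2, 64, 8, 8, 3)
out_linear = interp(patches, linear_weight, None, patch_size=(8, 8), is_linear=True)
# out_linear: [2, 64, 768]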