
Commit dac2ec6

Add missing patch embed interpolator
1 parent 0d43942 commit dac2ec6

File tree

2 files changed: +161 -2 lines changed

timm/layers/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -41,8 +41,7 @@
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_dropout import PatchDropout
-from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
-from .patch_embed_interpolator import PatchEmbedInterpolator
+from .patch_embed import PatchEmbed, PatchEmbedWithSize, PatchEmbedInterpolator, resample_patch_embed
 from .pool1d import global_pool_nlc
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc

timm/layers/patch_embed.py

Lines changed: 160 additions & 0 deletions
@@ -436,6 +436,166 @@ def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
         return resampled_patch_embed


+class PatchEmbedInterpolator(nn.Module):
+    """Dynamically interpolates patch embedding weights for variable patch sizes.
+
+    This module wraps patch embedding weight resampling functionality to support
+    on-the-fly patch size variation during training. It handles both Conv2d and
+    Linear patch embeddings.
+
+    Args:
+        base_patch_size: The original patch size the model was initialized with
+        in_chans: Number of input channels
+        embed_dim: Embedding dimension
+        interpolation: Interpolation mode for resampling
+        antialias: Whether to use antialiasing during interpolation
+    """
+
+    def __init__(
+            self,
+            base_patch_size: Tuple[int, int],
+            in_chans: int = 3,
+            embed_dim: int = 768,
+            interpolation: str = 'bicubic',
+            antialias: bool = True,
+    ):
+        super().__init__()
+        self.base_patch_size = base_patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        self.interpolation = interpolation
+        self.antialias = antialias
+
+    def resample_linear_weight(
+            self,
+            weight: torch.Tensor,
+            target_patch_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Resample linear patch embedding weights for a new patch size.
+
+        Args:
+            weight: Linear weight tensor of shape [embed_dim, patch_h * patch_w * in_chans]
+            target_patch_size: Target (patch_h, patch_w) to resample to
+
+        Returns:
+            Resampled weight tensor
+        """
+        if target_patch_size == self.base_patch_size:
+            return weight
+
+        embed_dim = weight.shape[0]
+        base_ph, base_pw = self.base_patch_size
+        target_ph, target_pw = target_patch_size
+
+        # Reshape linear weight to conv2d format
+        # [embed_dim, ph*pw*C] -> [embed_dim, C, ph, pw]
+        weight_conv = weight.reshape(embed_dim, base_ph, base_pw, self.in_chans)
+        weight_conv = weight_conv.permute(0, 3, 1, 2)
+
+        # Resample using existing function
+        weight_conv_resampled = resample_patch_embed(
+            weight_conv,
+            new_size=[target_ph, target_pw],
+            interpolation=self.interpolation,
+            antialias=self.antialias,
+            verbose=False,
+        )
+
+        # Reshape back to linear format
+        # [embed_dim, C, ph, pw] -> [embed_dim, ph*pw*C]
+        weight_resampled = weight_conv_resampled.permute(0, 2, 3, 1)
+        weight_resampled = weight_resampled.reshape(embed_dim, -1)
+
+        return weight_resampled
+
+    def resample_conv_weight(
+            self,
+            weight: torch.Tensor,
+            target_patch_size: Tuple[int, int],
+    ) -> torch.Tensor:
+        """Resample conv2d patch embedding weights for a new patch size.
+
+        Args:
+            weight: Conv2d weight tensor of shape [embed_dim, in_chans, patch_h, patch_w]
+            target_patch_size: Target (patch_h, patch_w) to resample to
+
+        Returns:
+            Resampled weight tensor
+        """
+        if target_patch_size == self.base_patch_size:
+            return weight
+
+        # Resample using existing function
+        weight_resampled = resample_patch_embed(
+            weight,
+            new_size=list(target_patch_size),
+            interpolation=self.interpolation,
+            antialias=self.antialias,
+            verbose=False,
+        )
+
+        return weight_resampled
+
+    def forward(
+            self,
+            patches: torch.Tensor,
+            proj_weight: torch.Tensor,
+            proj_bias: Optional[torch.Tensor] = None,
+            patch_size: Optional[Tuple[int, int]] = None,
+            is_linear: bool = True,
+    ) -> torch.Tensor:
+        """Apply patch embedding with dynamic weight resampling.
+
+        Args:
+            patches: Input patches
+                - For linear mode with resampling: [B, N, Ph, Pw, C]
+                - For linear mode without resampling: [B, N, Ph*Pw*C]
+                - For conv mode: [B, C, H, W]
+            proj_weight: Original projection weight
+            proj_bias: Optional projection bias
+            patch_size: Current patch size (if None, uses base_patch_size)
+            is_linear: Whether using linear (True) or conv2d (False) projection
+
+        Returns:
+            Embedded patches
+        """
+        if patch_size is None:
+            patch_size = self.base_patch_size
+
+        if is_linear:
+            if patch_size != self.base_patch_size:
+                # Need to resample - expects unflattened patches
+                assert patches.ndim == 5, "Patches must be [B, N, Ph, Pw, C] for resampling"
+                B, N, Ph, Pw, C = patches.shape
+
+                # Resample the weight
+                weight_resampled = self.resample_linear_weight(proj_weight, patch_size)
+
+                # Flatten patches and apply linear projection
+                patches_flat = patches.reshape(B, N, -1)
+                output = torch.nn.functional.linear(patches_flat, weight_resampled, proj_bias)
+            else:
+                # No resampling needed, patches can be pre-flattened
+                if patches.ndim == 5:
+                    B, N, Ph, Pw, C = patches.shape
+                    patches = patches.reshape(B, N, -1)
+                output = torch.nn.functional.linear(patches, proj_weight, proj_bias)
+        else:
+            # Conv mode
+            if patch_size != self.base_patch_size:
+                weight_resampled = self.resample_conv_weight(proj_weight, patch_size)
+                output = torch.nn.functional.conv2d(
+                    patches, weight_resampled, proj_bias,
+                    stride=patch_size, padding=0
+                )
+            else:
+                output = torch.nn.functional.conv2d(
+                    patches, proj_weight, proj_bias,
+                    stride=patch_size, padding=0
+                )

+        return output
+
 # def divs(n, m=None):
 #     m = m or n // 2
 #     if m == 1:
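Not part of the commit, but for context: a minimal usage sketch of the new PatchEmbedInterpolator as exported from timm.layers above. The shapes and sample values below are illustrative assumptions, not taken from the repository; only the class name, constructor arguments, and forward signature come from the diff.

import torch
from timm.layers import PatchEmbedInterpolator

# Illustrative sizes (assumptions): a ViT-Base style embed trained with 16x16 patches.
embed_dim, in_chans = 768, 3
base_patch = (16, 16)
interp = PatchEmbedInterpolator(
    base_patch_size=base_patch,
    in_chans=in_chans,
    embed_dim=embed_dim,
)

# Linear projection weight in [embed_dim, Ph*Pw*C] layout with (Ph, Pw, C) flatten
# order, matching the resample_linear_weight docstring above.
proj_weight = torch.randn(embed_dim, base_patch[0] * base_patch[1] * in_chans)
proj_bias = torch.zeros(embed_dim)

# Unflattened patches extracted at a different patch size (12x12), so the
# forward pass resamples the projection weight on the fly.
B, N, ph, pw = 2, 196, 12, 12
patches = torch.randn(B, N, ph, pw, in_chans)

out = interp(
    patches,
    proj_weight,
    proj_bias,
    patch_size=(ph, pw),  # differs from base_patch_size -> triggers resampling
    is_linear=True,
)
print(out.shape)  # torch.Size([2, 196, 768])

For a Conv2d patch embed, the same call with is_linear=False takes NCHW input and a [embed_dim, in_chans, Ph, Pw] weight, and applies conv2d with stride equal to the current patch size.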