2020import torch .nn .functional as F
2121from torch import nn
2222
23- from timm .layers import to_2tuple , make_divisible , GroupNorm1 , ConvMlp , DropPath
23+ from timm .layers import to_2tuple , make_divisible , GroupNorm1 , ConvMlp , DropPath , is_exportable
2424from ._builder import build_model_with_cfg
2525from ._features_fx import register_notrace_module
2626from ._registry import register_model
@@ -564,6 +564,7 @@ def __init__(
564564
565565 self .patch_size = to_2tuple (patch_size )
566566 self .patch_area = self .patch_size [0 ] * self .patch_size [1 ]
567+ self .coreml_exportable = is_exportable ()
567568
568569 def forward (self , x : torch .Tensor ) -> torch .Tensor :
569570 B , C , H , W = x .shape
@@ -580,16 +581,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
580581
581582 # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
582583 C = x .shape [1 ]
583- x = x .reshape (B , C , num_patch_h , patch_h , num_patch_w , patch_w ).permute (0 , 1 , 3 , 5 , 2 , 4 )
584+ if self .coreml_exportable :
585+ x = F .unfold (x , kernel_size = (patch_h , patch_w ), stride = (patch_h , patch_w ))
586+ else :
587+ x = x .reshape (B , C , num_patch_h , patch_h , num_patch_w , patch_w ).permute (0 , 1 , 3 , 5 , 2 , 4 )
584588 x = x .reshape (B , C , - 1 , num_patches )
585589
586590 # Global representations
587591 x = self .transformer (x )
588592 x = self .norm (x )
589593
590594 # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
591- x = x .reshape (B , C , patch_h , patch_w , num_patch_h , num_patch_w ).permute (0 , 1 , 4 , 2 , 5 , 3 )
592- x = x .reshape (B , C , num_patch_h * patch_h , num_patch_w * patch_w )
595+ if self .coreml_exportable :
596+ # adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
597+ x = x .reshape (B , C * patch_h * patch_w , num_patch_h , num_patch_w )
598+ x = F .pixel_shuffle (x , upscale_factor = patch_h )
599+ else :
600+ x = x .reshape (B , C , patch_h , patch_w , num_patch_h , num_patch_w ).permute (0 , 1 , 4 , 2 , 5 , 3 )
601+ x = x .reshape (B , C , num_patch_h * patch_h , num_patch_w * patch_w )
602+
593603
594604 x = self .conv_proj (x )
595605 return x
0 commit comments