@@ -270,47 +270,33 @@ def _compute_resize_matrix(
270270
271271 eye_matrix = torch .eye (old_total , device = device , dtype = dtype )
272272 basis_vectors_batch = eye_matrix .reshape (old_total , 1 , old_h , old_w )
273-
274273 resized_basis_vectors_batch = F .interpolate (
275274 basis_vectors_batch ,
276275 size = new_size ,
277276 mode = interpolation ,
278277 antialias = antialias ,
279278 align_corners = False
280279 ) # Output shape: (old_total, 1, new_h, new_w)
281-
282- resize_matrix = resized_basis_vectors_batch .squeeze (1 ).reshape (old_total , new_total ).T
280+ resize_matrix = resized_basis_vectors_batch .squeeze (1 ).permute (1 , 2 , 0 ).reshape (new_total , old_total )
283281 return resize_matrix # Shape: (new_total, old_total)
284282
285283
286- def _compute_pinv_for_resampling (resize_matrix : torch .Tensor ) -> torch .Tensor :
287- """Calculates the pseudoinverse matrix used for the resampling operation."""
288- pinv_matrix = torch .linalg .pinv (resize_matrix .T ) # Shape: (new_total, old_total)
289- return pinv_matrix
290-
291-
292284def _apply_resampling (
293285 patch_embed : torch .Tensor ,
294286 pinv_matrix : torch .Tensor ,
295287 new_size_tuple : Tuple [int , int ],
296288 orig_dtype : torch .dtype ,
297289 intermediate_dtype : torch .dtype = DTYPE_INTERMEDIATE
298290) -> torch .Tensor :
299- """Applies the precomputed pinv_matrix to resample the patch_embed tensor."""
300- try :
301- from torch import vmap
302- except ImportError :
303- from functorch import vmap
304-
305- def resample_kernel (kernel : torch .Tensor ) -> torch .Tensor :
306- kernel_flat = kernel .reshape (- 1 ).to (intermediate_dtype )
307- resampled_kernel_flat = pinv_matrix @ kernel_flat
308- return resampled_kernel_flat .reshape (new_size_tuple )
309-
310- resample_kernel_vmap = vmap (vmap (resample_kernel , in_dims = 0 , out_dims = 0 ), in_dims = 0 , out_dims = 0 )
311- patch_embed_float = patch_embed .to (intermediate_dtype )
312- resampled_patch_embed = resample_kernel_vmap (patch_embed_float )
313- return resampled_patch_embed .to (orig_dtype )
291+ """ Simplified resampling w/o vmap use.
292+ As proposed by https://github.com/stas-sl
293+ """
294+ c_out , c_in , * _ = patch_embed .shape
295+ patch_embed = patch_embed .reshape (c_out , c_in , - 1 ).to (dtype = intermediate_dtype )
296+ pinv_matrix = pinv_matrix .to (dtype = intermediate_dtype )
297+ resampled_patch_embed = patch_embed @ pinv_matrix # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new)
298+ resampled_patch_embed = resampled_patch_embed .reshape (c_out , c_in , * new_size_tuple ).to (dtype = orig_dtype )
299+ return resampled_patch_embed
314300
315301
316302def resample_patch_embed (
@@ -336,7 +322,7 @@ def resample_patch_embed(
336322 resize_mat = _compute_resize_matrix (
337323 old_size_tuple , new_size_tuple , interpolation , antialias , device , DTYPE_INTERMEDIATE
338324 )
339- pinv_matrix = _compute_pinv_for_resampling (resize_mat )
325+ pinv_matrix = torch . linalg . pinv (resize_mat ) # Calculates the pseudoinverse matrix used for resampling
340326 resampled_patch_embed = _apply_resampling (
341327 patch_embed , pinv_matrix , new_size_tuple , orig_dtype , DTYPE_INTERMEDIATE
342328 )
@@ -388,7 +374,7 @@ def _get_or_create_pinv_matrix(
388374 resize_mat = _compute_resize_matrix (
389375 self .orig_size , new_size , self .interpolation , self .antialias , device , dtype
390376 )
391- pinv_matrix = _compute_pinv_for_resampling (resize_mat )
377+ pinv_matrix = torch . linalg . pinv (resize_mat ) # Calculates the pseudoinverse matrix used for resampling
392378
393379 # Cache using register_buffer
394380 buffer_name = f"pinv_{ new_size [0 ]} x{ new_size [1 ]} "
0 commit comments