@@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
         nn.init.ones_(m.weight)
 
 
-def resize_pos_embed(posemb, posemb_new, num_tokens=1):
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
@@ -363,11 +363,12 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
     gs_old = int(math.sqrt(len(posemb_grid)))
-    gs_new = int(math.sqrt(ntok_new))
-    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
+    if not len(gs_new):  # backwards compatibility
+        gs_new = [int(math.sqrt(ntok_new))] * 2
+    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
+    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
     return posemb
 
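For context, a minimal usage sketch of the changed function, not part of the diff itself: the 12x20 target grid, embedding width, and random tensor contents below are illustrative assumptions, and it presumes a timm build that includes this change. With `gs_new` supplied, the grid is bilinearly interpolated to a rectangular target; left empty, it falls back to the old square grid derived from `sqrt(ntok_new)`.

```python
import torch
from timm.models.vision_transformer import resize_pos_embed

# Pretrained embedding: class token + 14x14 patch grid, width 768 (illustrative).
posemb = torch.randn(1, 1 + 14 * 14, 768)

# Target model uses a non-square 12x20 patch grid.
posemb_new = torch.randn(1, 1 + 12 * 20, 768)
resized = resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[12, 20])
assert resized.shape == posemb_new.shape  # (1, 1 + 12 * 20, 768)

# Omitting gs_new keeps the old behaviour: a square grid inferred via sqrt.
square_new = torch.randn(1, 1 + 16 * 16, 768)
assert resize_pos_embed(posemb, square_new, num_tokens=1).shape == square_new.shape
```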
@@ -385,7 +386,8 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1),
+                                 model.patch_embed.grid_size)
         out_dict[k] = v
     return out_dict
 
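On the caller side, a hedged sketch of the effect (the model name and input size are assumptions, not from the diff): building a ViT at a non-square input size now passes the model's actual `patch_embed.grid_size` through `checkpoint_filter_fn` into `resize_pos_embed`, so the square pretrained grid is stretched to the rectangular one instead of being forced through a square `sqrt` reshape.

```python
import timm

# A 384x768 input with 16x16 patches gives a 24x48 patch grid; loading the
# pretrained (square, 24x24) checkpoint now resizes the pos embed to match.
# Downloads weights on first use.
model = timm.create_model('vit_base_patch16_384', pretrained=True, img_size=(384, 768))
print(model.patch_embed.grid_size)  # (24, 48)
print(model.pos_embed.shape)        # torch.Size([1, 1 + 24 * 48, 768])
```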