@@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa
352352 nn .init .ones_ (m .weight )
353353
354354
355- def resize_pos_embed (posemb , posemb_new , num_tokens = 1 , gs_new = [] ):
355+ def resize_pos_embed (posemb , posemb_new , num_tokens = 1 , gs_new = () ):
356356 # Rescale the grid of position embeddings when loading from state_dict. Adapted from
357357 # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
358358 _logger .info ('Resized position embedding: %s to %s' , posemb .shape , posemb_new .shape )
@@ -363,8 +363,9 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=[]):
363363 else :
364364 posemb_tok , posemb_grid = posemb [:, :0 ], posemb [0 ]
365365 gs_old = int (math .sqrt (len (posemb_grid )))
366- if not len (gs_new ): # backwards compatibility
367- gs_new = [int (math .sqrt (ntok_new ))]* 2
366+ if not len (gs_new ): # backwards compatibility
367+ gs_new = [int (math .sqrt (ntok_new ))] * 2
368+ assert len (gs_new ) >= 2
368369 _logger .info ('Position embedding grid-size from %s to %s' , [gs_old , gs_old ], gs_new )
369370 posemb_grid = posemb_grid .reshape (1 , gs_old , gs_old , - 1 ).permute (0 , 3 , 1 , 2 )
370371 posemb_grid = F .interpolate (posemb_grid , size = gs_new , mode = 'bilinear' )
0 commit comments