@@ -771,28 +771,20 @@ def resize_pos_embed(
         antialias: bool = False,
 ) -> torch.Tensor:
     """ Rescale the grid of position embeddings when loading from state_dict.
-
-    *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
-
-    Adapted from:
-        https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
     """
-    ntok_new = posemb_new.shape[1]
-    if num_prefix_tokens:
-        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
-        ntok_new -= num_prefix_tokens
-    else:
-        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
+    ntok_new = posemb_new.shape[1] - num_prefix_tokens
+    ntok_old = posemb.shape[1] - num_prefix_tokens
+    gs_old = [int(math.sqrt(ntok_old))] * 2
     if not len(gs_new):  # backwards compatibility
         gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info(f'Resized position embedding: {posemb.shape} ({[gs_old, gs_old]}) to {posemb_new.shape} ({gs_new}).')
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode=interpolation, antialias=antialias, align_corners=False)
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
-    return posemb
+    return resample_abs_pos_embed(
+        posemb, gs_new, gs_old,
+        num_prefix_tokens=num_prefix_tokens,
+        interpolation=interpolation,
+        antialias=antialias,
+        verbose=True,
+    )
 
 
 @torch.no_grad()
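Migration note: the deprecated helper is now a thin wrapper, so callers can move to resample_abs_pos_embed directly. A minimal sketch, assuming a timm build that exports it from timm.layers (shapes here are illustrative):

    import torch
    from timm.layers import resample_abs_pos_embed

    # Illustrative checkpoint pos embed: 1 class token + 14x14 grid, dim 768.
    posemb = torch.randn(1, 1 + 14 * 14, 768)

    # Resample to a 16x16 grid; the prefix (class) token is split off, the
    # grid portion is interpolated, and the two are re-concatenated.
    posemb = resample_abs_pos_embed(
        posemb,
        new_size=[16, 16],        # target (H, W) grid
        num_prefix_tokens=1,      # class token, kept as-is
    )
    print(posemb.shape)  # torch.Size([1, 257, 768])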
@@ -962,16 +954,6 @@ def _convert_openai_clip(
             v = v.unsqueeze(0).unsqueeze(1)
         elif k == 'pos_embed':
             v = v.unsqueeze(0)
-            if v.shape[1] != model.pos_embed.shape[1]:
-                # To resize pos embedding when using model at different size from pretrained weights
-                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
-                    else getattr(model, 'num_prefix_tokens', 1)
-                v = resample_abs_pos_embed(
-                    v,
-                    new_size=model.patch_embed.grid_size,
-                    num_prefix_tokens=num_prefix_tokens,
-                    verbose=True,
-                )
         out_dict[k] = v
     return out_dict
 
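With this branch gone, _convert_openai_clip no longer resamples pos_embed itself; together with the fall-through change in checkpoint_filter_fn below, any size mismatch is presumably left to the shared resize path that runs after conversion. For reference, the removed logic as a standalone helper (maybe_resample_pos_embed is an illustrative name, not a timm API):

    def maybe_resample_pos_embed(v, model):
        # Resize the pos embedding when using the model at a different size
        # from the pretrained weights (same logic as the removed branch).
        if v.shape[1] != model.pos_embed.shape[1]:
            num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
                else getattr(model, 'num_prefix_tokens', 1)
            v = resample_abs_pos_embed(
                v,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                verbose=True,
            )
        return v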
@@ -1014,19 +996,17 @@ def checkpoint_filter_fn(
     prefix = ''
 
     if 'visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model)
+        state_dict = _convert_openai_clip(state_dict, model)
     elif 'module.visual.class_embedding' in state_dict:
-        return _convert_openai_clip(state_dict, model, prefix='module.visual.')
-
-    if "mask_token" in state_dict:
+        state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
+    elif "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
-
-    if "encoder" in state_dict:
+    elif "encoder" in state_dict:
+        # IJEPA, vit in an 'encoder' submodule
         state_dict = state_dict['encoder']
         prefix = 'module.'
-
-    if 'visual.trunk.pos_embed' in state_dict:
-        # convert an OpenCLIP model with timm vision encoder
+    elif 'visual.trunk.pos_embed' in state_dict:
+        # OpenCLIP model with timm vision encoder
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         prefix = 'visual.trunk.'
 
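The switch from early returns to an elif chain is the functional core of this hunk: a CLIP-converted state_dict previously returned immediately and skipped the prefix stripping and key remapping that follow in checkpoint_filter_fn, whereas now exactly one converter runs and every branch falls through to that shared code. A condensed, hypothetical view of the new control flow:

    def filter_fn(state_dict, model):
        # At most one converter fires, then shared remapping applies to all.
        if 'visual.class_embedding' in state_dict:
            state_dict = _convert_openai_clip(state_dict, model)  # no early return
        elif 'mask_token' in state_dict:
            state_dict = _convert_dinov2(state_dict, model)
        # ...shared prefix stripping / key remapping runs here for every branch...
        return state_dict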