@@ -2003,42 +2003,21 @@ def _create_naflexvit_from_eva(
20032003 Returns:
20042004 NaFlexVit model instance
20052005 """
2006- # Map EVA-specific parameters to NaFlexVit equivalents
2007-
2008- # Handle EVA's unique parameters
2009- kwargs .pop ('no_embed_class' , None ) # EVA specific, not used in NaFlexVit
2010-
2011- # abs pos embed
2012- use_abs_pos_emb = kwargs .pop ('use_abs_pos_emb' , True )
2006+ # Handle EVA's unique parameters & block args
2007+ kwargs .pop ('no_embed_class' , None ) # EVA specific, not used in NaFlexVit (always no-embed)
20132008
20142009 # Map EVA's rope parameters
20152010 use_rot_pos_emb = kwargs .pop ('use_rot_pos_emb' , False )
20162011 rope_mixed_mode = kwargs .pop ('rope_mixed_mode' , False )
20172012 rope_temperature = kwargs .pop ('rope_temperature' , 10000. )
20182013 rope_grid_offset = kwargs .pop ('rope_grid_offset' , 0. )
20192014 rope_grid_indexing = kwargs .pop ('rope_grid_indexing' , 'ij' )
2020-
2021- # Get EVA's attn_type directly
2022- attn_type = kwargs .pop ('attn_type' , 'eva' )
2023-
2024- # Determine rope_type based on EVA parameters
20252015 if use_rot_pos_emb :
20262016 rope_type = 'mixed' if rope_mixed_mode else 'axial'
20272017 else :
20282018 rope_type = 'none'
20292019
2030- # Handle EVA's swiglu_mlp and scale_mlp
2031- swiglu_mlp = kwargs .pop ('swiglu_mlp' , False )
2032- scale_mlp = kwargs .pop ('scale_mlp' , False )
2033- scale_attn_inner = kwargs .pop ('scale_attn_inner' , False )
2034-
2035- # Map qkv_fused parameter
2036- qkv_fused = kwargs .pop ('qkv_fused' , True )
2037-
2038- # Handle register tokens
2039- num_reg_tokens = kwargs .pop ('num_reg_tokens' , kwargs .get ('reg_tokens' , 0 ))
2040-
2041- # Handle global pooling
2020+ # Handle global pooling logic to mirror EVA use
20422021 gp = kwargs .pop ('global_pool' , 'avg' )
20432022 fc_norm = kwargs .pop ('fc_norm' , None )
20442023 if fc_norm is None and gp == 'avg' :
@@ -2048,23 +2027,22 @@ def _create_naflexvit_from_eva(
20482027 naflex_kwargs = {
20492028 'pos_embed_grid_size' : None , # rely on img_size (// patch_size)
20502029 'class_token' : kwargs .get ('class_token' , True ),
2051- 'reg_tokens' : num_reg_tokens ,
2030+ 'reg_tokens' : kwargs . pop ( ' num_reg_tokens' , kwargs . get ( 'reg_tokens' , 0 )) ,
20522031 'global_pool' : gp ,
20532032 'fc_norm' : fc_norm ,
2054- 'pos_embed' : 'learned' if use_abs_pos_emb else 'none' ,
2033+ 'pos_embed' : 'learned' if kwargs . pop ( ' use_abs_pos_emb' , True ) else 'none' ,
20552034 'rope_type' : rope_type ,
20562035 'rope_temperature' : rope_temperature ,
20572036 'rope_grid_offset' : rope_grid_offset ,
20582037 'rope_grid_indexing' : rope_grid_indexing ,
20592038 'rope_ref_feat_shape' : kwargs .get ('ref_feat_shape' , None ),
2060- 'attn_type' : attn_type ,
2061- 'swiglu_mlp' : swiglu_mlp ,
2062- 'scale_mlp ' : scale_mlp ,
2063- 'scale_attn_inner ' : scale_attn_inner ,
2064- 'qkv_fused ' : qkv_fused ,
2039+ 'attn_type' : kwargs . pop ( ' attn_type' , 'eva' ) ,
2040+ 'swiglu_mlp' : kwargs . pop ( ' swiglu_mlp' , False ) ,
2041+ 'qkv_fused ' : kwargs . pop ( 'qkv_fused' , True ) ,
2042+ 'scale_mlp_norm ' : kwargs . pop ( 'scale_mlp' , False ) ,
2043+ 'scale_attn_inner_norm ' : kwargs . pop ( 'scale_attn_inner' , False ) ,
20652044 ** kwargs # Pass remaining kwargs through
20662045 }
2067- print (naflex_kwargs )
20682046
20692047 return _create_naflexvit (variant , pretrained , ** naflex_kwargs )
20702048
0 commit comments