@@ -2017,19 +2017,23 @@ def _create_naflexvit_from_eva(
20172017 else :
20182018 rope_type = 'none'
20192019
2020- # Handle global pooling logic to mirror EVA use
2020+ # Handle norm/pool resolution logic to mirror EVA
20212021 gp = kwargs .pop ('global_pool' , 'avg' )
2022- fc_norm = kwargs .pop ('fc_norm' , None )
2023- if fc_norm is None and gp == 'avg' :
2024- fc_norm = True
2022+ use_pre_transformer_norm = kwargs .pop ('use_pre_transformer_norm' , False )
2023+ use_post_transformer_norm = kwargs .pop ('use_post_transformer_norm' , True )
2024+ use_fc_norm = kwargs .pop ('use_fc_norm' , None )
2025+ if use_fc_norm is None :
2026+ use_fc_norm = gp == 'avg' # default on if avg pool used
20252027
20262028 # Set NaFlexVit-specific parameters
20272029 naflex_kwargs = {
20282030 'pos_embed_grid_size' : None , # rely on img_size (// patch_size)
20292031 'class_token' : kwargs .get ('class_token' , True ),
20302032 'reg_tokens' : kwargs .pop ('num_reg_tokens' , kwargs .get ('reg_tokens' , 0 )),
20312033 'global_pool' : gp ,
2032- 'fc_norm' : fc_norm ,
2034+ 'pre_norm' : use_pre_transformer_norm ,
2035+ 'final_norm' : use_post_transformer_norm ,
2036+ 'fc_norm' : use_fc_norm ,
20332037 'pos_embed' : 'learned' if kwargs .pop ('use_abs_pos_emb' , True ) else 'none' ,
20342038 'rope_type' : rope_type ,
20352039 'rope_temperature' : rope_temperature ,
0 commit comments