1 file changed: +7 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -106,12 +106,16 @@ class NaFlexVitCfg:
106106 # Image processing
107107 dynamic_img_pad : bool = False # Whether to enable dynamic padding for variable resolution
108108
109- # Architecture choices
109+ # Other architecture choices
110110 pre_norm : bool = False # Whether to apply normalization before attention/MLP layers (start of blocks)
111111 final_norm : bool = True # Whether to apply final normalization before pooling and classifier (end of blocks)
112112 fc_norm : Optional [bool ] = None # Whether to normalize features before final classifier (after pooling)
113+
114+ # Global pooling setup
113115 global_pool : str = 'map' # Type of global pooling for final sequence
114116 pool_include_prefix : bool = False # Whether to include class/register prefix tokens in global pooling
117+ attn_pool_num_heads : Optional [int ] = None # Override num_heads for attention pool
118+ attn_pool_mlp_ratio : Optional [float ] = None # Override mlp_ratio for attention pool
115119
116120 # Weight initialization
117121 weight_init : str = '' # Weight initialization scheme
@@ -1212,8 +1216,8 @@ def __init__(
12121216 if cfg .global_pool == 'map' :
12131217 self .attn_pool = AttentionPoolLatent (
12141218 self .embed_dim ,
1215- num_heads = cfg .num_heads ,
1216- mlp_ratio = cfg .mlp_ratio ,
1219+ num_heads = cfg .attn_pool_num_heads or cfg . num_heads ,
1220+ mlp_ratio = cfg .attn_pool_mlp_ratio or cfg . mlp_ratio ,
12171221 norm_layer = norm_layer ,
12181222 act_layer = act_layer ,
12191223 )
You can’t perform that action at this time.
0 commit comments