1 file changed: +9 −20 lines changed
@@ -1192,27 +1192,16 @@ def _pool(
             patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.attn_pool is not None:
-            # For attention pooling, we need to pass the mask for NaFlex models
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
-            if self.pool_include_prefix:
+            if not self.pool_include_prefix:
-                # Include all tokens in attention pooling - create mask for all tokens including prefix
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=self.num_prefix_tokens,
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x, attn_mask=attn_mask)
-            else:
-                # Exclude prefix tokens from attention pooling (default behavior)
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=0,  # No prefix tokens when we slice them off
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
+                x = x[:, self.num_prefix_tokens:]
+            x = self.attn_pool(x, attn_mask=attn_mask)
             return x

         pool_type = self.global_pool if pool_type is None else pool_type
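
The change collapses the two duplicated create_attention_mask calls into a single call whose num_prefix_tokens argument follows pool_include_prefix, leaving only the prefix-token slice inside the branch before attention pooling. For context, below is a minimal, self-contained sketch of the same idea; it is not the timm implementation, and the helper name make_pool_mask, the tensor shapes, and the use of scaled_dot_product_attention as a stand-in for the attn_pool module are assumptions made for illustration:

# Minimal sketch, not the timm implementation: the helper name, shapes, and the
# toy pooling step are assumptions made for illustration only.
import torch
import torch.nn.functional as F


def make_pool_mask(
        patch_valid: torch.Tensor,   # (B, N) bool, True where a patch is real (not padding)
        num_prefix_tokens: int = 0,  # prefix tokens (class/register) are always attendable
        q_len: int = 1,              # a single pooling query attends over all keys
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    if num_prefix_tokens:
        prefix_valid = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    # Additive bias: 0 where attention is allowed, a large negative value on padding keys.
    mask = torch.zeros(patch_valid.shape[0], 1, q_len, patch_valid.shape[1], dtype=dtype)
    return mask.masked_fill(~patch_valid[:, None, None, :], torch.finfo(dtype).min)


# Toy usage: pool a padded token sequence with one query, mirroring the branch above.
B, N, D, P = 2, 7, 16, 1                      # batch, patches, embed dim, prefix tokens
x = torch.randn(B, P + N, D)                  # prefix token(s) followed by patch tokens
patch_valid = torch.arange(N)[None, :] < torch.tensor([[5], [7]])  # per-sample valid lengths
pool_include_prefix = False

mask = make_pool_mask(patch_valid, num_prefix_tokens=P if pool_include_prefix else 0, dtype=x.dtype)
if not pool_include_prefix:
    x = x[:, P:]                              # drop prefix tokens before pooling
q = torch.zeros(B, 1, 1, D)                   # stand-in for a learned pooling query
pooled = F.scaled_dot_product_attention(q, x[:, None], x[:, None], attn_mask=mask)[:, 0, 0]
print(pooled.shape)                           # torch.Size([2, 16])

With the mask built once, the only per-branch decision left is whether the prefix tokens stay in the pooled sequence, which is what the simplified code in the diff expresses.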