Skip to content

Commit 0991a3c

Browse files
author
Maxime France-Pillois
authored
Add key to FA Autotuner (#4199)
Add "STAGE" as a key for the autotuner to optimize the configuration for CAUSAL FA as well. --------- Signed-off-by: Maxime France-Pillois <maxime.francepillois@codeplay.com>
1 parent d6264cf commit 0991a3c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,12 @@ def _attn_fwd_with_block_pointers(Q, K, V, sm_scale, M, Out, #
160160
configs = [
161161
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True}, num_stages=s, num_warps=w) \
162162
for BM in [128, 256] \
163-
for BN in [32, 64, 128] \
163+
for BN in [32, 64] \
164164
for s in [2, 3, 4] \
165165
for w in [8, 16, 32] \
166166
]
167167

168-
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
168+
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL', 'STAGE'])
169169

170170

171171
@triton.jit

0 commit comments

Comments
 (0)