1+ """
2+ Copyright (c) 2024 by FlashInfer team.
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ http://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ """
16+
 import numpy as np
 import torch
 
@@ -8,39 +24,65 @@
 )
 
 
-def bench_single_prefill(seq_len, num_heads, causal, head_dim):
-    num_qo_heads = num_kv_heads = num_heads
-    q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
-    k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
-    v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
-
-    sm80_ms, sm90_ms = (
-        np.median(
-            bench_gpu_time(
-                lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
-                    q, k, v, causal=causal, backend=backend
-                ),
-                dry_run_time_ms=100,
-                repeat_time_ms=1000,
-            )
-        )
-        for backend in ["fa2", "fa3"]
+def per_head_symmetric_quant(x, quant_dtype):
+    """Per-head symmetric quantization to FP8."""
+    o_min_val, o_max_val = (
+        (-448.0, 448.0) if quant_dtype == torch.float8_e4m3fn else (-57344, 57344)
+    )
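+    # Note: 448 is the largest finite value representable in float8_e4m3fn and
+    # 57344 the largest in float8_e5m2; values are clamped to this range below.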
+    x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32)
+    s_out = torch.clamp(x_max_val / o_max_val, min=1e-6)
+    s_out_broadcast = s_out.view(1, -1, 1)
+    q_x_out = torch.clamp(x / s_out_broadcast, min=o_min_val, max=o_max_val).to(
+        dtype=quant_dtype
     )
+    return q_x_out, s_out
 
-    q = torch.randn(
+
+def bench_fp8_single_prefill(
+    seq_len, num_heads, causal, head_dim, dtype=torch.float8_e4m3fn
+):
+    """Benchmark FP8 single prefill attention."""
+    num_qo_heads = num_kv_heads = num_heads
+
+    # Create FP16 tensors first, then quantize
+    q_fp16 = torch.randn(
         seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
-    ).to(dtype=torch.float8_e4m3fn)
-    k = torch.randn(
+    )
+    k_fp16 = torch.randn(
         seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
-    ).to(dtype=torch.float8_e4m3fn)
-    v = torch.randn(
+    )
+    v_fp16 = torch.randn(
         seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
-    ).to(dtype=torch.float8_e4m3fn)
+    )
+
+    # Quantize to FP8
+    q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
+    k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype)
+    v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype)
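+    # s_q, s_k and s_v are float32 scales of shape (num_heads,); the FP8 call
+    # below forwards them as scale_q / scale_k / scale_v.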
 
-    fp8_sm90_ms = np.median(
+    # FP16 baseline (fa3)
+    fp16_ms = np.median(
         bench_gpu_time(
             lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
-                q, k, v, causal=causal, backend="fa3", o_dtype=torch.half
+                q_fp16, k_fp16, v_fp16, causal=causal, backend="fa3"
+            ),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
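+    # bench_gpu_time returns a collection of per-run timings in milliseconds;
+    # the median is taken to reduce run-to-run noise.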
+
+    # FP8 (fa3)
+    fp8_ms = np.median(
+        bench_gpu_time(
+            lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
+                q_fp8,
+                k_fp8,
+                v_fp8,
+                causal=causal,
+                backend="fa3",
+                scale_q=s_q,
+                scale_k=s_k,
+                scale_v=s_v,
             ),
             dry_run_time_ms=100,
             repeat_time_ms=1000,
@@ -59,7 +101,222 @@ def flops(ms):
     )
 
     print(
-        f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"
+        f"bench_fp8_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), "
+        f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f} ms), "
+        f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f} ms), "
+        f"speedup: {fp16_ms / fp8_ms:.2f}x"
+    )
+
+
+def bench_fp8_batch_ragged_prefill(
+    batch_size, num_heads, seq_len, causal, head_dim, dtype=torch.float8_e4m3fn
+):
+    """Benchmark FP8 batch ragged prefill attention."""
+    num_qo_heads = num_kv_heads = num_heads
+    total_len = batch_size * seq_len
+
+    # Create FP16 tensors first
+    q_fp16 = torch.randn(
+        total_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    k_fp16 = torch.randn(
+        total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    v_fp16 = torch.randn(
+        total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+
+    # Quantize to FP8
+    q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
+    k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype)
+    v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype)
+
+    qo_indptr = torch.arange(
+        0, total_len + 1, seq_len, dtype=torch.int32, device="cuda"
+    )
+    kv_indptr = torch.arange(
+        0, total_len + 1, seq_len, dtype=torch.int32, device="cuda"
+    )
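+    # All requests share the same length, so both indptr arrays are simply
+    # [0, seq_len, 2 * seq_len, ..., batch_size * seq_len].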
+
+    # FP16 wrapper
+    fp16_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp16_wrapper.plan(
+        qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal
+    )
+
+    # FP8 wrapper
+    fp8_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp8_wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        o_data_type=torch.half,
+        causal=causal,
+    )
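+    # Compared with the fp16 plan above, this plan declares the q/kv data type
+    # as the FP8 dtype and requests fp16 output.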
+
+    fp16_ms = np.median(
+        bench_gpu_time(
+            lambda: fp16_wrapper.run(q_fp16, k_fp16, v_fp16),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    fp8_ms = np.median(
+        bench_gpu_time(
+            lambda: fp8_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    def flops(ms):
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.full((batch_size,), seq_len),
+            torch.full((batch_size,), seq_len),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
+
+    print(
+        f"bench_fp8_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), "
+        f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f} ms), "
+        f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f} ms), "
+        f"speedup: {fp16_ms / fp8_ms:.2f}x"
+    )
+
+
+def bench_fp8_batch_paged_prefill(
+    page_size,
+    batch_size,
+    num_heads,
+    seq_len,
+    causal,
+    head_dim,
+    dtype=torch.float8_e4m3fn,
+):
+    """Benchmark FP8 batch paged prefill attention."""
+    num_qo_heads = num_kv_heads = num_heads
+    total_qo_len = batch_size * seq_len
+    num_pages = batch_size * seq_len // page_size
+
+    # Create FP16 tensors first
+    q_fp16 = torch.randn(
+        total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    # Paged KV cache: (num_pages, page_size, num_heads, head_dim)
+    k_fp16 = torch.randn(
+        num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    v_fp16 = torch.randn(
+        num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+
+    # Quantize to FP8
+    q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
+    # For paged KV, reshape to (total_tokens, num_heads, head_dim) for quantization
+    k_flat = k_fp16.view(-1, num_kv_heads, head_dim)
+    v_flat = v_fp16.view(-1, num_kv_heads, head_dim)
+    k_fp8_flat, s_k = per_head_symmetric_quant(k_flat, dtype)
+    v_fp8_flat, s_v = per_head_symmetric_quant(v_flat, dtype)
+    k_fp8 = k_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim)
+    v_fp8 = v_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim)
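+    # per_head_symmetric_quant reduces over dims (0, 2), so the paged cache is
+    # flattened to (num_pages * page_size, num_heads, head_dim) for quantization
+    # and then viewed back into its paged layout.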
+
+    qo_indptr = torch.arange(
+        0, total_qo_len + 1, seq_len, dtype=torch.int32, device="cuda"
+    )
+    kv_indptr = torch.arange(
+        0, num_pages + 1, seq_len // page_size, dtype=torch.int32, device="cuda"
+    )
+    kv_indices = torch.arange(0, num_pages, dtype=torch.int32, device="cuda")
+    last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * page_size
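+    # seq_len is assumed to be divisible by page_size, so every request owns
+    # seq_len // page_size full pages and its last page is completely filled.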
+
+    # FP16 wrapper
+    fp16_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp16_wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        causal=causal,
+    )
+
+    # FP8 wrapper
+    fp8_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp8_wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        o_data_type=torch.half,
+        causal=causal,
+    )
+
+    fp16_ms = np.median(
+        bench_gpu_time(
+            lambda: fp16_wrapper.run(q_fp16, (k_fp16, v_fp16)),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    fp8_ms = np.median(
+        bench_gpu_time(
+            lambda: fp8_wrapper.run(q_fp8, (k_fp8, v_fp8), s_q, s_k, s_v),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    def flops(ms):
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.full((batch_size,), seq_len),
+            torch.full((batch_size,), seq_len),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
+
+    print(
+        f"bench_fp8_batch_paged_prefill (page_size={page_size}, batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), "
+        f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f} ms), "
+        f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f} ms), "
+        f"speedup: {fp16_ms / fp8_ms:.2f}x"
     )
 
 
@@ -70,8 +327,30 @@ def flops(ms):
         print("Current benchmark targets capability (9, 0). Returning...")
         exit()
 
-    for seq_len in [4096, 8192, 16384]:
-        for num_heads in [24, 32]:
-            for causal in [True, False]:
-                for head_dim in [64, 128, 256]:
-                    bench_single_prefill(seq_len, num_heads, causal, head_dim)
+    # Skip single prefill for now due to compilation issues
+    # print("=" * 80)
+    # print("FP8 Single Prefill Benchmarks")
+    # print("=" * 80)
+    # for head_dim in [128, 256]:
+    #     for seq_len in [1024, 4096, 8192]:
+    #         bench_fp8_single_prefill(seq_len, 32, True, head_dim)
+
+    print()
+    print("=" * 80)
+    print("FP8 Batch Ragged Prefill Benchmarks")
+    print("=" * 80)
+    for head_dim in [128, 256]:
+        bench_fp8_batch_ragged_prefill(128, 32, 1024, True, head_dim)
+        bench_fp8_batch_ragged_prefill(64, 32, 2048, True, head_dim)
+        bench_fp8_batch_ragged_prefill(32, 32, 4096, True, head_dim)
+        bench_fp8_batch_ragged_prefill(16, 32, 8192, True, head_dim)
+
+    print()
+    print("=" * 80)
+    print("FP8 Batch Paged Prefill Benchmarks")
+    print("=" * 80)
+    for head_dim in [128, 256]:
+        bench_fp8_batch_paged_prefill(16, 128, 32, 1024, True, head_dim)
+        bench_fp8_batch_paged_prefill(16, 64, 32, 2048, True, head_dim)
+        bench_fp8_batch_paged_prefill(16, 32, 32, 4096, True, head_dim)
+        bench_fp8_batch_paged_prefill(16, 16, 32, 8192, True, head_dim)
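+    # Each configuration above keeps batch_size * seq_len fixed at 131072 tokens,
+    # so results within a sweep are comparable across batch/sequence splits.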