Commit f89ae64

upd
1 parent 876c386 commit f89ae64

File tree: 3 files changed, +397 -64 lines changed

Lines changed: 310 additions & 31 deletions
@@ -1,3 +1,19 @@
+"""
+Copyright (c) 2024 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
 import numpy as np
 import torch
 
@@ -8,39 +24,65 @@
 )
 
 
-def bench_single_prefill(seq_len, num_heads, causal, head_dim):
-    num_qo_heads = num_kv_heads = num_heads
-    q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
-    k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
-    v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
-
-    sm80_ms, sm90_ms = (
-        np.median(
-            bench_gpu_time(
-                lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
-                    q, k, v, causal=causal, backend=backend
-                ),
-                dry_run_time_ms=100,
-                repeat_time_ms=1000,
-            )
-        )
-        for backend in ["fa2", "fa3"]
+def per_head_symmetric_quant(x, quant_dtype):
+    """Per-head symmetric quantization to FP8."""
+    o_min_val, o_max_val = (
+        (-448.0, 448.0) if quant_dtype == torch.float8_e4m3fn else (-57344, 57344)
+    )
+    x_max_val = x.abs().amax(dim=(0, 2)).to(dtype=torch.float32)
+    s_out = torch.clamp(x_max_val / o_max_val, min=1e-6)
+    s_out_broadcast = s_out.view(1, -1, 1)
+    q_x_out = torch.clamp(x / s_out_broadcast, min=o_min_val, max=o_max_val).to(
+        dtype=quant_dtype
     )
+    return q_x_out, s_out
 
-    q = torch.randn(
+
+def bench_fp8_single_prefill(
+    seq_len, num_heads, causal, head_dim, dtype=torch.float8_e4m3fn
+):
+    """Benchmark FP8 single prefill attention."""
+    num_qo_heads = num_kv_heads = num_heads
+
+    # Create FP16 tensors first, then quantize
+    q_fp16 = torch.randn(
         seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
-    ).to(dtype=torch.float8_e4m3fn)
-    k = torch.randn(
+    )
+    k_fp16 = torch.randn(
         seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
-    ).to(dtype=torch.float8_e4m3fn)
-    v = torch.randn(
+    )
+    v_fp16 = torch.randn(
         seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
-    ).to(dtype=torch.float8_e4m3fn)
+    )
+
+    # Quantize to FP8
+    q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
+    k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype)
+    v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype)
 
-    fp8_sm90_ms = np.median(
+    # FP16 baseline (fa3)
+    fp16_ms = np.median(
         bench_gpu_time(
             lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
-                q, k, v, causal=causal, backend="fa3", o_dtype=torch.half
+                q_fp16, k_fp16, v_fp16, causal=causal, backend="fa3"
+            ),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    # FP8 (fa3)
+    fp8_ms = np.median(
+        bench_gpu_time(
+            lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
+                q_fp8,
+                k_fp8,
+                v_fp8,
+                causal=causal,
+                backend="fa3",
+                scale_q=s_q,
+                scale_k=s_k,
+                scale_v=s_v,
             ),
             dry_run_time_ms=100,
             repeat_time_ms=1000,
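
For reference, the per-head symmetric quantization helper added above can be sanity-checked on its own: quantize, dequantize with the returned per-head scales, and inspect the round-trip error. A minimal sketch in plain PyTorch (shapes and names chosen here for illustration; a recent PyTorch with float8 dtypes is assumed):

import torch

def per_head_quant_roundtrip(x, quant_dtype=torch.float8_e4m3fn):
    # Same idea as per_head_symmetric_quant above: one symmetric scale per head (dim 1).
    o_max_val = 448.0 if quant_dtype == torch.float8_e4m3fn else 57344.0
    x_max_val = x.abs().amax(dim=(0, 2)).to(torch.float32)
    s = torch.clamp(x_max_val / o_max_val, min=1e-6)
    q_x = torch.clamp(x / s.view(1, -1, 1), min=-o_max_val, max=o_max_val).to(quant_dtype)
    # Dequantize with the per-head scales and report the worst-case absolute error.
    x_hat = q_x.to(torch.float32) * s.view(1, -1, 1)
    return (x.float() - x_hat).abs().max()

x = torch.randn(1024, 32, 128).half()  # (seq_len, num_heads, head_dim)
print(per_head_quant_roundtrip(x))     # worst-case absolute round-trip error
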
@@ -59,7 +101,222 @@ def flops(ms):
         )
 
     print(
-        f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"
+        f"bench_fp8_single_prefill (seq_len={seq_len}, num_heads={num_heads}, causal={causal}, head_dim={head_dim}), "
+        f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), "
+        f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), "
+        f"speedup: {fp16_ms / fp8_ms:.2f}x"
+    )
+
+
+def bench_fp8_batch_ragged_prefill(
+    batch_size, num_heads, seq_len, causal, head_dim, dtype=torch.float8_e4m3fn
+):
+    """Benchmark FP8 batch ragged prefill attention."""
+    num_qo_heads = num_kv_heads = num_heads
+    total_len = batch_size * seq_len
+
+    # Create FP16 tensors first
+    q_fp16 = torch.randn(
+        total_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    k_fp16 = torch.randn(
+        total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    v_fp16 = torch.randn(
+        total_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+
+    # Quantize to FP8
+    q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
+    k_fp8, s_k = per_head_symmetric_quant(k_fp16, dtype)
+    v_fp8, s_v = per_head_symmetric_quant(v_fp16, dtype)
+
+    qo_indptr = torch.arange(
+        0, total_len + 1, seq_len, dtype=torch.int32, device="cuda"
+    )
+    kv_indptr = torch.arange(
+        0, total_len + 1, seq_len, dtype=torch.int32, device="cuda"
+    )
+
+    # FP16 wrapper
+    fp16_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp16_wrapper.plan(
+        qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal
+    )
+
+    # FP8 wrapper
+    fp8_wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp8_wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        o_data_type=torch.half,
+        causal=causal,
+    )
+
+    fp16_ms = np.median(
+        bench_gpu_time(
+            lambda: fp16_wrapper.run(q_fp16, k_fp16, v_fp16),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    fp8_ms = np.median(
+        bench_gpu_time(
+            lambda: fp8_wrapper.run(q_fp8, k_fp8, v_fp8, s_q, s_k, s_v),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    def flops(ms):
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.full((batch_size,), seq_len),
+            torch.full((batch_size,), seq_len),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
+
+    print(
+        f"bench_fp8_batch_ragged_prefill (batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), "
+        f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), "
+        f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), "
+        f"speedup: {fp16_ms / fp8_ms:.2f}x"
+    )
+
+
+def bench_fp8_batch_paged_prefill(
+    page_size,
+    batch_size,
+    num_heads,
+    seq_len,
+    causal,
+    head_dim,
+    dtype=torch.float8_e4m3fn,
+):
+    """Benchmark FP8 batch paged prefill attention."""
+    num_qo_heads = num_kv_heads = num_heads
+    total_qo_len = batch_size * seq_len
+    num_pages = batch_size * seq_len // page_size
+
+    # Create FP16 tensors first
+    q_fp16 = torch.randn(
+        total_qo_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    # Paged KV cache: (num_pages, page_size, num_heads, head_dim)
+    k_fp16 = torch.randn(
+        num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+    v_fp16 = torch.randn(
+        num_pages, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
+    )
+
+    # Quantize to FP8
+    q_fp8, s_q = per_head_symmetric_quant(q_fp16, dtype)
+    # For paged KV, reshape to (total_tokens, num_heads, head_dim) for quantization
+    k_flat = k_fp16.view(-1, num_kv_heads, head_dim)
+    v_flat = v_fp16.view(-1, num_kv_heads, head_dim)
+    k_fp8_flat, s_k = per_head_symmetric_quant(k_flat, dtype)
+    v_fp8_flat, s_v = per_head_symmetric_quant(v_flat, dtype)
+    k_fp8 = k_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim)
+    v_fp8 = v_fp8_flat.view(num_pages, page_size, num_kv_heads, head_dim)
+
+    qo_indptr = torch.arange(
+        0, total_qo_len + 1, seq_len, dtype=torch.int32, device="cuda"
+    )
+    kv_indptr = torch.arange(
+        0, num_pages + 1, seq_len // page_size, dtype=torch.int32, device="cuda"
+    )
+    kv_indices = torch.arange(0, num_pages, dtype=torch.int32, device="cuda")
+    last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda") * page_size
+
+    # FP16 wrapper
+    fp16_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp16_wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        causal=causal,
+    )
+
+    # FP8 wrapper
+    fp8_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device="cuda"),
+        kv_layout="NHD",
+        backend="fa3",
+    )
+    fp8_wrapper.plan(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        q_data_type=dtype,
+        kv_data_type=dtype,
+        o_data_type=torch.half,
+        causal=causal,
+    )
+
+    fp16_ms = np.median(
+        bench_gpu_time(
+            lambda: fp16_wrapper.run(q_fp16, (k_fp16, v_fp16)),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    fp8_ms = np.median(
+        bench_gpu_time(
+            lambda: fp8_wrapper.run(q_fp8, (k_fp8, v_fp8), s_q, s_k, s_v),
+            dry_run_time_ms=100,
+            repeat_time_ms=1000,
+        )
+    )
+
+    def flops(ms):
+        return attention_tflops_per_sec_with_actual_seq_lens(
+            torch.full((batch_size,), seq_len),
+            torch.full((batch_size,), seq_len),
+            head_dim,
+            head_dim,
+            num_qo_heads,
+            causal,
+            ms,
+        )
+
+    print(
+        f"bench_fp8_batch_paged_prefill (page_size={page_size}, batch_size={batch_size}, num_heads={num_heads}, seq_len={seq_len}, causal={causal}, head_dim={head_dim}), "
+        f"fp16: {flops(fp16_ms):.3f} TFLOPs/s ({fp16_ms:.3f}ms), "
+        f"fp8: {flops(fp8_ms):.3f} TFLOPs/s ({fp8_ms:.3f}ms), "
+        f"speedup: {fp16_ms / fp8_ms:.2f}x"
     )
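
The index bookkeeping in both batch benchmarks above is the uniform-length special case: every request contributes exactly seq_len tokens, so the indptr arrays are plain arithmetic progressions and every KV page is completely full. A small sketch of the same arithmetic (plain PyTorch, example sizes chosen for illustration):

import torch

batch_size, seq_len, page_size = 4, 1024, 16
total_len = batch_size * seq_len
num_pages = total_len // page_size
pages_per_req = seq_len // page_size

# Ragged layout: request i owns rows [qo_indptr[i], qo_indptr[i + 1]) of the packed q/k/v.
qo_indptr = torch.arange(0, total_len + 1, seq_len, dtype=torch.int32)
print(qo_indptr)  # tensor([   0, 1024, 2048, 3072, 4096], dtype=torch.int32)

# Paged layout: request i owns pages kv_indices[kv_indptr[i]:kv_indptr[i + 1]],
# and its last page holds last_page_len[i] valid tokens (all pages full here).
kv_indptr = torch.arange(0, num_pages + 1, pages_per_req, dtype=torch.int32)
kv_indices = torch.arange(0, num_pages, dtype=torch.int32)
last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32)
assert int(kv_indptr[-1]) == num_pages == len(kv_indices)
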
@@ -70,8 +327,30 @@ def flops(ms):
         print("Current benchmark targets capability (9, 0). Returning...")
         exit()
 
-    for seq_len in [4096, 8192, 16384]:
-        for num_heads in [24, 32]:
-            for causal in [True, False]:
-                for head_dim in [64, 128, 256]:
-                    bench_single_prefill(seq_len, num_heads, causal, head_dim)
+    # Skip single prefill for now due to compilation issues
+    # print("=" * 80)
+    # print("FP8 Single Prefill Benchmarks")
+    # print("=" * 80)
+    # for head_dim in [128, 256]:
+    #     for seq_len in [1024, 4096, 8192]:
+    #         bench_fp8_single_prefill(seq_len, 32, True, head_dim)
+
+    print()
+    print("=" * 80)
+    print("FP8 Batch Ragged Prefill Benchmarks")
+    print("=" * 80)
+    for head_dim in [128, 256]:
+        bench_fp8_batch_ragged_prefill(128, 32, 1024, True, head_dim)
+        bench_fp8_batch_ragged_prefill(64, 32, 2048, True, head_dim)
+        bench_fp8_batch_ragged_prefill(32, 32, 4096, True, head_dim)
+        bench_fp8_batch_ragged_prefill(16, 32, 8192, True, head_dim)
+
+    print()
+    print("=" * 80)
+    print("FP8 Batch Paged Prefill Benchmarks")
+    print("=" * 80)
+    for head_dim in [128, 256]:
+        bench_fp8_batch_paged_prefill(16, 128, 32, 1024, True, head_dim)
+        bench_fp8_batch_paged_prefill(16, 64, 32, 2048, True, head_dim)
+        bench_fp8_batch_paged_prefill(16, 32, 32, 4096, True, head_dim)
+        bench_fp8_batch_paged_prefill(16, 16, 32, 8192, True, head_dim)
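
The TFLOPs/s figures printed by these benchmarks come from flashinfer's attention_tflops_per_sec_with_actual_seq_lens helper. A rough back-of-the-envelope version of the same quantity, shown only to make the reported numbers easier to interpret (the usual 4 * q_len * kv_len * head_dim * num_heads operation count, halved for causal masking; an approximation, not necessarily the helper's exact formula):

def approx_attention_tflops_per_sec(batch_size, seq_len, head_dim, num_heads, causal, ms):
    # Two GEMMs (Q @ K^T and P @ V), 2 FLOPs per multiply-add.
    flops = 4.0 * batch_size * seq_len * seq_len * head_dim * num_heads
    if causal:
        flops /= 2  # roughly half of the score matrix is masked out
    return flops / (ms * 1e-3) / 1e12

# e.g. one ragged config above: batch_size=128, seq_len=1024, head_dim=128, 32 heads, causal
print(approx_attention_tflops_per_sec(128, 1024, 128, 32, True, ms=5.0))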
