[Feature] Support batch prefill for POD Attention (#2079)
<!-- .github/pull_request_template.md -->
Co-authored-by: @Edenzzzz

## 📌 Description

Fixes #1022. Unlike #1231, this splits the inputs into separate prefill and decode inputs. It should be possible to handle this splitting automatically in Python so that callers can simply provide a single batch of requests.

To run the benchmark: `python benchmarks/bench_mixed_attention.py`

Performance:

**Benchmark 1:** Prefill = 2 requests, 2048 Q len, 2048 KV len; Decode = 128 requests, 2048 KV len

- Elapsed time (Batched Prefill): 0.65 ms
- Elapsed time (Batched POD Attention): 0.46 ms
- Elapsed time (Persistent BatchAttention): 0.56 ms
- **Batch POD speedup over Persistent BatchAttention: 1.22x**

**Benchmark 2:** Prefill = 1 request, 2048 Q len, 2048 KV len; Decode = 128 requests, 2048 KV len

- Elapsed time (Batched Prefill): 0.55 ms
- Elapsed time (Batched POD Attention): 0.41 ms
- Elapsed time (POD Attention): 0.41 ms
- Elapsed time (Sequential two kernels): 0.51 ms
- Elapsed time (Persistent BatchAttention): 0.45 ms
- **Batch POD speedup over Persistent BatchAttention: 1.11x**

**Benchmark 3:** Prefill = 1 request, 4096 Q len, 4096 KV len; Decode = 128 requests, 4096 KV len

- Elapsed time (Batched Prefill): 1.27 ms
- Elapsed time (Batched POD Attention): 0.86 ms
- Elapsed time (POD Attention): 0.82 ms
- Elapsed time (Sequential two kernels): 1.15 ms
- Elapsed time (Persistent BatchAttention): 1.08 ms
- **Batch POD speedup over Persistent BatchAttention: 1.26x**

**Benchmark 4:** Prefill = 1 request, 4096 Q len, 4096 KV len; Decode = 128 requests, 8192 KV len

- Elapsed time (Batched Prefill): 2.15 ms
- Elapsed time (Batched POD Attention): 1.52 ms
- Elapsed time (POD Attention): 1.54 ms
- Elapsed time (Sequential two kernels): 1.82 ms
- Elapsed time (Persistent BatchAttention): 1.76 ms
- **Batch POD speedup over Persistent BatchAttention: 1.16x**

**Benchmark 5:** Prefill = 1 request, 6000 Q len, 7000 KV len; Decode = 128 requests, 8192 KV len

- Elapsed time (Batched Prefill): 2.86 ms
- Elapsed time (Batched POD Attention): 2.03 ms
- Elapsed time (POD Attention): 1.95 ms
- Elapsed time (Sequential two kernels): 2.52 ms
- Elapsed time (Persistent BatchAttention): 2.45 ms
- **Batch POD speedup over Persistent BatchAttention: 1.20x**

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added a batched prefill+decode attention path with a public batch-oriented POD wrapper and JIT module export.
* **Performance**
  * Benchmarks extended to include batched-path timings, memory bandwidth, elapsed-time, and comparative speedup metrics across expanded prefill/decode scenarios.
* **API**
  * Runtime binding for batched KV-cache execution added; planning APIs now accept an optional colocated-CTA parameter that influences scheduling.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Aditya K Kamath <akamath1997@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
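For a quick sense of how the new batched path is driven, below is a minimal, self-contained sketch condensed from the `bench_mixed_attention.py` changes in this commit. The toy sizes and the `to_indptr` helper are illustrative only, and the `BatchPODWithPagedKVCacheWrapper.plan()`/`run()` argument order is taken from the benchmark diff rather than from separate API documentation, so treat this as a sketch, not a reference implementation.

```python
import torch
import flashinfer

device = "cuda"
num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 1

# One prefill request (16 query/KV tokens) plus 4 decode requests
# (1 query token, 32 KV tokens each), kept as separate groups.
p_qo_lens, p_kv_lens = [16], [16]
d_qo_lens, d_kv_lens = [1] * 4, [32] * 4


def to_indptr(lens):
    # Offsets with a leading zero, mirroring the benchmark's cumsum pattern.
    pref = torch.cumsum(torch.tensor(lens), 0)
    return torch.cat([torch.tensor([0]), pref], dim=0).int().to(device)


p_q_indptr, d_q_indptr = to_indptr(p_qo_lens), to_indptr(d_qo_lens)
# With page_size == 1, each KV token is its own page and every last page is full.
p_kv_indptr, d_kv_indptr = to_indptr(p_kv_lens), to_indptr(d_kv_lens)
kv_indices_p = torch.arange(p_kv_indptr[-1].item(), device=device, dtype=torch.int32)
kv_indices_d = torch.arange(d_kv_indptr[-1].item(), device=device, dtype=torch.int32)
last_page_len_p = torch.ones(len(p_kv_lens), dtype=torch.int32, device=device)
last_page_len_d = torch.ones(len(d_kv_lens), dtype=torch.int32, device=device)

bf16 = torch.bfloat16
q_p = torch.randn(sum(p_qo_lens), num_qo_heads, head_dim, device=device, dtype=bf16)
q_d = torch.randn(sum(d_qo_lens), num_qo_heads, head_dim, device=device, dtype=bf16)
# Paged KV caches in NHD layout: [pages, 2, page_size, num_kv_heads, head_dim],
# unbound into (K, V) the same way the benchmark calls kv_data.unbind(1).
kv_p = torch.randn(
    sum(p_kv_lens), 2, page_size, num_kv_heads, head_dim, device=device, dtype=bf16
).unbind(1)
kv_d = torch.randn(
    sum(d_kv_lens), 2, page_size, num_kv_heads, head_dim, device=device, dtype=bf16
).unbind(1)

workspace = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchPODWithPagedKVCacheWrapper(workspace, kv_layout="NHD")
wrapper.plan(
    # Prefill params
    p_q_indptr, p_kv_indptr, kv_indices_p, last_page_len_p,
    # Decode params
    d_q_indptr, d_kv_indptr, kv_indices_d, last_page_len_d,
    # Common params
    num_qo_heads=num_qo_heads,
    num_kv_heads=num_kv_heads,
    head_dim=head_dim,
    page_size=page_size,
    q_data_type=bf16,
    kv_data_type=bf16,
)
o_p, o_d = wrapper.run(q_p, kv_p, q_d, kv_d, causal_p=True)
```

The outputs come back as separate prefill and decode tensors; the benchmark reassembles them with `torch.cat([o_d, o_p], dim=0)` to compare against the combined-batch baselines.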
1 parent 9a79b78 commit 636a3ab

18 files changed: +1725 -16 lines

benchmarks/bench_mixed_attention.py

Lines changed: 95 additions & 7 deletions
@@ -23,14 +23,25 @@ def run_bench(
     q_lens = torch.tensor(d_qo_lens + p_qo_lens, dtype=torch.int32)
 
     seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int()
-    d_seq_lens_blocks = (
+    p_seq_lens_blocks = torch.ceil(
+        torch.tensor(p_kv_lens, dtype=torch.int32) / page_block_size
+    ).int()
+    d_seq_lens_blocks = torch.ceil(
         torch.tensor(d_kv_lens, dtype=torch.int32) / page_block_size
     ).int()
 
     q_indptr = torch.cat([torch.tensor([0]), torch.cumsum(q_lens, 0)], dim=0).int()
     kv_indptr = torch.cat(
         [torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0
     ).int()
+
+    p_q_indptr = torch.cat(
+        [torch.tensor([0]), torch.cumsum(torch.tensor(p_qo_lens), 0)], dim=0
+    ).int()
+    p_kv_indptr = torch.cat(
+        [torch.tensor([0]), torch.cumsum(p_seq_lens_blocks, 0)], dim=0
+    ).int()
+
     d_q_indptr = torch.cat(
         [torch.tensor([0]), torch.cumsum(torch.tensor(d_qo_lens), 0)], dim=0
     ).int()
@@ -46,7 +57,7 @@ def run_bench(
         device, dtype=torch.bfloat16
     )
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
+    workspace_buffer = torch.empty(156 * 1024 * 1024, dtype=torch.uint8, device=device)
     kv_layout = "NHD"
 
     wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
@@ -90,7 +101,67 @@ def run_bench(
     o_persistent, _ = wrapper_persistent.run(q, kv_data)
     measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
     ms_persistent = np.mean(measurements_persistent)
+
+    # Batched POD Attention
+    q_d = q[: d_q_indptr[-1]]
+    kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
+    q_p = q[d_q_indptr[-1] :]
+    kv_p = kv_data[d_kv_indptr[-1] :].unbind(1)
+    kv_indices_d = torch.arange(0, d_kv_indptr[-1], device=device, dtype=torch.int32)
+    kv_indices_p = torch.arange(0, p_kv_indptr[-1], device=device, dtype=torch.int32)
+
+    last_page_len_d = (d_seq_lens_blocks - 1) % page_block_size + 1
+    last_page_len_p = (p_seq_lens_blocks - 1) % page_block_size + 1
+    wrapper_pod = flashinfer.BatchPODWithPagedKVCacheWrapper(
+        workspace_buffer,
+        kv_layout=kv_layout,
+    )
+
+    wrapper_pod.plan(
+        # Prefill params
+        p_q_indptr.to(device),
+        p_kv_indptr.to(device),
+        kv_indices_p.to(device),
+        last_page_len_p,
+        # Decode params
+        d_q_indptr.to(device),
+        d_kv_indptr.to(device),
+        kv_indices_d.to(device),
+        last_page_len_d,
+        # Common params
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        page_size=page_block_size,
+        q_data_type=torch.bfloat16,
+        kv_data_type=torch.bfloat16,
+    )
+    o_p_batch, o_d_batch = wrapper_pod.run(
+        q_p,
+        kv_p,
+        q_d,
+        kv_d,
+        causal_p=causal,
+    )
+    o_batch_pod = torch.cat([o_d_batch, o_p_batch], dim=0)
+
+    # Verify output matches
+    torch.testing.assert_close(
+        o_batch_pod, o, rtol=4e-3, atol=4e-3, msg="Batch POD-Attention decode mismatch!"
+    )
+    measurements = bench_gpu_time(
+        lambda: wrapper_pod.run(
+            q_p,
+            kv_p,
+            q_d,
+            kv_d,
+            causal_p=causal,
+        )
+    )
+    ms_batch_pod = np.median(measurements)
+
     if len(p_kv_lens) == 1:
+        # Single POD attention
         q_d = q[: d_q_indptr[-1]]
         kv_d = kv_data[: d_kv_indptr[-1]].unbind(1)
         q_p = q[d_q_indptr[-1] :]
@@ -127,7 +198,7 @@ def run_bench(
         o_pod = torch.cat([o_d, o_p], dim=0)
         # Verify output matches
         torch.testing.assert_close(
-            o, o_pod, rtol=1e-3, atol=1e-3, msg="POD-Attention output mismatch!"
+            o, o_pod, rtol=4e-3, atol=4e-3, msg="POD-Attention output mismatch!"
         )
         measurements = bench_gpu_time(
             lambda: wrapper_pod.run(
@@ -177,10 +248,15 @@ def _run_single_prefill():
         ms_seq_two_kernels = ms_prefill + ms_decode
 
     print(f"Elapsed time (Batched Prefill): {ms_old:.2f} ms")
+    print(f"Elapsed time (Batched POD Attention): {ms_batch_pod:.2f} ms")
     if len(p_kv_lens) == 1:
         print(f"Elapsed time (POD Attention): {ms_pod:.2f} ms")
         print(f"Elapsed time (Sequential two kernels): {ms_seq_two_kernels:.2f} ms")
     print(f"Elapsed time (Persistent BatchAttention): {ms_persistent:.2f} ms")
+    print(
+        f"Batch POD speedup over Persistent BatchAttention: {ms_persistent / ms_batch_pod:.2f}x"
+    )
+
     total_bytes = (
         q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
     )
@@ -189,6 +265,10 @@ def _run_single_prefill():
     bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3) / (1024**3)
 
     print(f"Memory bandwidth (Batched Prefill): {bandwidth_old_gb_s:.2f} GB/s")
+    bandwidth_batch_pod_gb_s = total_bytes / (ms_batch_pod * 1e-3) / (1024**3)
+    print(
+        f"Memory bandwidth (Batched POD Attention): {bandwidth_batch_pod_gb_s:.2f} GB/s"
+    )
     if len(p_kv_lens) == 1:
         bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3) / (1024**3)
         print(f"Memory bandwidth (POD Attention): {bandwidth_pod_gb_s:.2f} GB/s")
@@ -207,10 +287,18 @@ def _run_single_prefill():
     torch.random.manual_seed(42)
 
     # Irregular sequence lengths for prefill and decode
-    d_q_len_configs = [[1] * 128, [1] * 128, [1] * 128, [1] * 128]
-    d_kv_len_configs = [[2048] * 128, [4096] * 128, [8192] * 128, [8192] * 128]
-    p_q_configs = [[2048], [4096], [4096], [6000]]
-    p_kv_configs = [[2048], [4096], [4096], [7000]]
+    d_q_len_configs = [[1] * 128] * 7
+    d_kv_len_configs = [
+        [2048] * 128,
+        [2048] * 128,
+        [2048] * 128,
+        [2048] * 128,
+        [4096] * 128,
+        [8192] * 128,
+        [8192] * 128,
+    ]
+    p_q_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [6000]]
+    p_kv_configs = [[512], [1536], [2048] * 2, [2048], [4096], [4096], [7000]]
 
     page_block_size = 1
     num_kv_heads = 8