Commit 636a3ab
[Feature] Support batch prefill for POD Attention (#2079)
<!-- .github/pull_request_template.md -->
Co-authored-by: @Edenzzzz
## 📌 Description
Fixes #1022. Unlike
#1231, this splits the
inputs into separate prefill and decode inputs. It should be possible to
handle this splitting automatically in Python, so callers could simply
provide a single mixed batch of requests.
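As a rough illustration of the splitting discussed above, the sketch below partitions one mixed batch into prefill (query length > 1) and decode (query length == 1) groups and builds CSR-style indptr arrays for each side. The helper name, the length-based criterion, and the tensor layout are assumptions for illustration only; see `benchmarks/bench_mixed_attention.py` for how the inputs are actually constructed in this PR.

```python
import torch

def split_mixed_batch(qo_lens, kv_lens, device="cpu"):
    """Partition a mixed batch into prefill/decode groups (illustrative sketch only)."""
    # Assumption: requests with a single query token are decode, the rest are prefill.
    prefill_ids = [i for i, q in enumerate(qo_lens) if q > 1]
    decode_ids = [i for i, q in enumerate(qo_lens) if q == 1]

    def indptr(lens):
        # CSR-style offsets: cumulative sum with a leading zero.
        out = [0]
        for n in lens:
            out.append(out[-1] + n)
        return torch.tensor(out, dtype=torch.int32, device=device)

    qo_indptr_p = indptr([qo_lens[i] for i in prefill_ids])
    kv_indptr_p = indptr([kv_lens[i] for i in prefill_ids])
    kv_indptr_d = indptr([kv_lens[i] for i in decode_ids])
    return prefill_ids, decode_ids, qo_indptr_p, kv_indptr_p, kv_indptr_d

# Example: two prefill requests plus three decode requests in one mixed batch.
p_ids, d_ids, qo_p, kv_p, kv_d = split_mixed_batch(
    qo_lens=[2048, 2048, 1, 1, 1],
    kv_lens=[2048, 2048, 4096, 4096, 4096],
)
```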
To run the benchmark for this change: `python
benchmarks/bench_mixed_attention.py`
Performance:
===== Benchmark 1: (kv_len, qo_len) set =====
Prefill = 2 requests, 2048 Q len, 2048 KV len
Decode = 128 requests, 2048 KV len
Elapsed time (Batched Prefill): 0.65 ms
Elapsed time (Batched POD Attention): 0.46 ms
Elapsed time (Persistent BatchAttention): 0.56 ms
**Batch POD speedup over Persistent BatchAttention: 1.22x**
===== Benchmark 2: (kv_len, qo_len) set =====
Prefill = 1 request, 2048 Q len, 2048 KV len
Decode = 128 requests, 2048 KV len
Elapsed time (Batched Prefill): 0.55 ms
Elapsed time (Batched POD Attention): 0.41 ms
Elapsed time (POD Attention): 0.41 ms
Elapsed time (Sequential two kernels): 0.51 ms
Elapsed time (Persistent BatchAttention): 0.45 ms
**Batch POD speedup over Persistent BatchAttention: 1.11x**
===== Benchmark 3: (kv_len, qo_len) set =====
Prefill = 1 request, 4096 Q len, 4096 KV len
Decode = 128 requests, 4096 KV len
Elapsed time (Batched Prefill): 1.27 ms
Elapsed time (Batched POD Attention): 0.86 ms
Elapsed time (POD Attention): 0.82 ms
Elapsed time (Sequential two kernels): 1.15 ms
Elapsed time (Persistent BatchAttention): 1.08 ms
**Batch POD speedup over Persistent BatchAttention: 1.26x**
===== Benchmark 4: (kv_len, qo_len) set =====
Prefill = 1 request, 4096 Q len, 4096 KV len
Decode = 128 requests, 8192 KV len
Elapsed time (Batched Prefill): 2.15 ms
Elapsed time (Batched POD Attention): 1.52 ms
Elapsed time (POD Attention): 1.54 ms
Elapsed time (Sequential two kernels): 1.82 ms
Elapsed time (Persistent BatchAttention): 1.76 ms
**Batch POD speedup over Persistent BatchAttention: 1.16x**
===== Benchmark 5: (kv_len, qo_len) set =====
Prefill = 1 request, 6000 Q len, 7000 KV len
Decode = 128 requests, 8192 KV len
Elapsed time (Batched Prefill): 2.86 ms
Elapsed time (Batched POD Attention): 2.03 ms
Elapsed time (POD Attention): 1.95 ms
Elapsed time (Sequential two kernels): 2.52 ms
Elapsed time (Persistent BatchAttention): 2.45 ms
**Batch POD speedup over Persistent BatchAttention: 1.20x**
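The quoted speedups are the ratio of the Persistent BatchAttention time to the Batched POD Attention time for each benchmark. A quick sanity check in Python, with the times copied from the runs above (the small deviations from the quoted 1.11x and 1.20x come from rounding in the printed milliseconds):

```python
# Verify the reported speedups from the elapsed times above.
times_ms = {
    1: (0.46, 0.56),  # (Batched POD Attention, Persistent BatchAttention)
    2: (0.41, 0.45),
    3: (0.86, 1.08),
    4: (1.52, 1.76),
    5: (2.03, 2.45),
}
for bench, (pod, persistent) in times_ms.items():
    print(f"Benchmark {bench}: {persistent / pod:.2f}x")
# -> 1.22x, 1.10x, 1.26x, 1.16x, 1.21x
```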
## 🔍 Related Issues
<!-- Link any related issues here -->
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or my preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added a batched prefill+decode attention path with a public
batch-oriented POD wrapper and JIT module export.
* **Performance**
* Benchmarks extended to include batched-path timings, memory bandwidth,
elapsed-time and comparative speedup metrics across expanded
prefill/decode scenarios.
* **API**
* Runtime binding for batched KV‑cache execution added; planning APIs
now accept an optional colocated-CTA parameter that influences
scheduling.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Co-authored-by: Aditya K Kamath <akamath1997@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
File tree
18 files changed: +1725 −16 lines
- benchmarks
- csrc
- flashinfer
  - jit
  - attention
- include/flashinfer/attention