feat: Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused RoPE + Q + KV cache, supports MLA/GQA/MHA) #2037
Status: Merged

Commits (13):
- 3f7b7ed: (RoPE + Q fp8 + append kv_cache) fused kernel for MLA/GQA/MHA (kahyunnam)
- 4040e9c: add a decode test (kahyunnam)
- 20e123b: align fake op registration with the custom op name (kahyunnam)
- 58a96f2: only 4D k_cache/v_cache (kahyunnam)
- 034be6a: add check: GQA/MHA expects a V tensor, but got None (kahyunnam)
- 368dc2b: unsqueeze mla into dim 3 to match mha/gqa (kahyunnam)
- 2556467: parameterize page size 16, 32 for testing (kahyunnam)
- 401bab3: when no_rope_dim, optional None for nope tensors (kahyunnam)
- 49b70ec: revise benchmarking script to use ncu (kahyunnam)
- 3a97554: upd (yzh119)
- ec566bc: small fixes (kahyunnam)
- 72bd59e: forgot to add utils (kahyunnam)
- 0f87afd: fix cicd test? (kahyunnam)
@@ -420,3 +420,203 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope

```cpp
  });
  });
}

/*!
 * TVM FFI binding for fused RoPE + quantization + paged KV cache append kernel
 *
 * Validates tensor shapes, dimensions, and data types, then dispatches to the templated
 * RopeQuantizeAppendPagedKVCache CUDA kernel implementation.
 */
void rope_quantize_append_paged_kv_cache(
    TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in,
    TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache,
    TensorView pos_ids,
    // Paged cache tensors
    TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache,
    TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions,
    int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv,
    bool interleave, bool enable_pdl) {
  // Validate inputs
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out);
  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  CHECK_INPUT(kv_indices);
  CHECK_INPUT(kv_indptr);
  CHECK_INPUT(batch_indices);
  CHECK_INPUT(positions);

  // Extract dimensions
  uint32_t rope_dim = q_rope_in.size(-1);
  uint32_t no_rope_dim = q_nope_in.size(-1);
  uint32_t nnz = q_rope_in.size(0);
  uint32_t num_qo_heads = q_rope_in.size(1);

  // Validate dimensions
  TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim);
  TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim);
  TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim);
  TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim);
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype());
  TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype());

  // Validate input/output dtypes
  TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16)
      << "Input dtype must be float16 or bfloat16";
  TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2)
      << "Output dtype must be float8_e4m3fn or float8_e5m2";

  // Q tensors are always 3D
  CHECK_DIM(3, q_rope_in);
  CHECK_DIM(3, q_nope_in);
  CHECK_DIM(3, q_rope_out);
  CHECK_DIM(3, q_nope_out);

  // Detect architecture: MLA (2D K) vs GQA/MHA (3D K)
  bool is_mla = (k_rope_in.ndim() == 2);
  uint32_t num_kv_heads;
  uint32_t batch_size = kv_indptr.size(0) - 1;
  QKVLayout kv_layout = QKVLayout(kv_layout_code);

  if (is_mla) {
    // MLA: K tensors are 2D
    CHECK_DIM(2, k_rope_in);
    CHECK_DIM(2, k_nope_in);
    num_kv_heads = 1;
    // V can be empty for MLA
    TVM_FFI_ICHECK(v_in.data_ptr() == nullptr || v_in.size(0) == 0)
        << "MLA should not have V input (or it should be empty)";
    // Validate MLA cache tensors are provided
    TVM_FFI_ICHECK(ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr)
        << "MLA requires ckv_cache and kpe_cache";
    CHECK_DIM(3, ckv_cache);  // (max_pages, page_size, ckv_dim)
    CHECK_DIM(3, kpe_cache);  // (max_pages, page_size, kpe_dim)
    TVM_FFI_ICHECK_EQ(ckv_cache.size(2), no_rope_dim);
    TVM_FFI_ICHECK_EQ(kpe_cache.size(2), rope_dim);
  } else {
    // GQA/MHA: K tensors are 3D
    CHECK_DIM(3, k_rope_in);
    CHECK_DIM(3, k_nope_in);
    num_kv_heads = k_rope_in.size(1);
    TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads);
    // V is required for GQA/MHA
    CHECK_DIM(3, v_in);
    TVM_FFI_ICHECK_EQ(v_in.size(0), nnz);
    TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads);
    // Validate GQA/MHA cache tensors are provided
    TVM_FFI_ICHECK(k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr)
        << "GQA/MHA requires k_cache and v_cache";
    // Cache must be 4D
    CHECK_DIM(4, k_cache);
    CHECK_DIM(4, v_cache);
  }

  // Extract Q strides
  const uint32_t q_rope_in_stride_n = q_rope_in.stride(0);
  const uint32_t q_rope_in_stride_h = q_rope_in.stride(1);
  const uint32_t q_nope_in_stride_n = q_nope_in.stride(0);
  const uint32_t q_nope_in_stride_h = q_nope_in.stride(1);
  const uint32_t q_rope_out_stride_n = q_rope_out.stride(0);
  const uint32_t q_rope_out_stride_h = q_rope_out.stride(1);
  const uint32_t q_nope_out_stride_n = q_nope_out.stride(0);
  const uint32_t q_nope_out_stride_h = q_nope_out.stride(1);

  // Extract K strides (architecture dependent)
  uint32_t k_rope_in_stride, k_nope_in_stride;
  uint32_t k_rope_in_stride_h, k_nope_in_stride_h;
  uint32_t v_in_stride = 0, v_in_stride_h = 0;

  if (is_mla) {
    // MLA: 2D K tensors
    k_rope_in_stride = k_rope_in.stride(0);
    k_nope_in_stride = k_nope_in.stride(0);
    k_rope_in_stride_h = k_rope_in_stride;  // Same as batch stride for 2D
    k_nope_in_stride_h = k_nope_in_stride;
  } else {
    // GQA/MHA: 3D K tensors
    k_rope_in_stride = k_rope_in.stride(0);
    k_rope_in_stride_h = k_rope_in.stride(1);
    k_nope_in_stride = k_nope_in.stride(0);
    k_nope_in_stride_h = k_nope_in.stride(1);
    v_in_stride = v_in.stride(0);
    v_in_stride_h = v_in.stride(1);
  }

  cudaSetDevice(q_rope_in.device().device_id);
  const cudaStream_t stream = get_stream(q_rope_in.device());

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
    return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] {
      cudaError_t status;

      if (is_mla) {
        // MLA: Construct paged_kv_mla_t struct
        auto ckv_strides = ckv_cache.strides();
        auto kpe_strides = kpe_cache.strides();

        paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla(
            page_size, no_rope_dim, rope_dim, batch_size,
            static_cast<c_quant_type*>(ckv_cache.data_ptr()), ckv_strides.data(),
            static_cast<c_quant_type*>(kpe_cache.data_ptr()), kpe_strides.data(),
            static_cast<int32_t*>(kv_indices.data_ptr()),
            static_cast<int32_t*>(kv_indptr.data_ptr()),
            nullptr  // last_page_len not needed for this kernel
        );

        status = RopeQuantizeAppendPagedMLACache(
            static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
            static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
            static_cast<c_quant_type*>(q_rope_out.data_ptr()),
            static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv_mla,
            static_cast<int32_t*>(batch_indices.data_ptr()),
            static_cast<int32_t*>(positions.data_ptr()),
            static_cast<float*>(cos_sin_cache.data_ptr()),
            static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim,
            q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
            q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
            k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave,
            enable_pdl, stream);
      } else {
        // GQA/MHA: Construct paged_kv_t struct
        auto k_strides = k_cache.strides();
        auto v_strides = v_cache.strides();
        uint32_t head_dim = rope_dim + no_rope_dim;

        paged_kv_t<c_quant_type, int32_t> paged_kv(
            num_kv_heads, page_size, head_dim, batch_size, kv_layout,
            static_cast<c_quant_type*>(k_cache.data_ptr()),
            static_cast<c_quant_type*>(v_cache.data_ptr()), k_strides.data(),
            static_cast<int32_t*>(kv_indices.data_ptr()),
            static_cast<int32_t*>(kv_indptr.data_ptr()),
            nullptr  // last_page_len not needed for this kernel
        );

        status = RopeQuantizeAppendPagedKVCache(
            static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
            static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
            static_cast<c_type*>(v_in.data_ptr()),
            static_cast<c_quant_type*>(q_rope_out.data_ptr()),
            static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv,
            static_cast<int32_t*>(batch_indices.data_ptr()),
            static_cast<int32_t*>(positions.data_ptr()),
            static_cast<float*>(cos_sin_cache.data_ptr()),
            static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
            no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
            q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
            q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
            k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv,
            interleave, enable_pdl, stream);
      }

      TVM_FFI_ICHECK(status == cudaSuccess)
          << "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status);
      return true;
    });
  });
}
```
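For reference, here is a hedged sketch of the paged-cache indexing that `kv_indices`, `kv_indptr`, `batch_indices`, and `positions` imply, assuming FlashInfer's usual convention (a request's page list is the slice of `kv_indices` selected by `kv_indptr`, with an NHD page layout). The function name, shapes, and layout below are illustrative only, not part of this PR:

```python
import torch

def append_token_to_paged_k_cache(
    k_cache: torch.Tensor,     # (max_pages, page_size, num_kv_heads, head_dim), fp8, NHD layout
    kv_indices: torch.Tensor,  # page ids of all requests, concatenated
    kv_indptr: torch.Tensor,   # (batch_size + 1,) offsets into kv_indices per request
    page_size: int,
    batch_index: int,          # request this token belongs to (from batch_indices[i])
    position: int,             # token's slot in that request's KV sequence (from positions[i])
    k_token: torch.Tensor,     # (num_kv_heads, head_dim) already-quantized key for this token
) -> None:
    # The token's page is found by walking the request's page list; the
    # in-page slot is the position modulo the page size.
    page = kv_indices[kv_indptr[batch_index] + position // page_size]
    entry = position % page_size
    k_cache[page, entry] = k_token
```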
Review comment: Can we unify the two branches? On the C++ side, assume the K tensor is always 3D; on the Python side, if the K tensor is found to be 2D, unsqueeze its dimension 1.
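A minimal sketch of that suggestion, assuming a PyTorch-based Python wrapper; the helper name `normalize_k_layout` is illustrative and not part of the PR:

```python
import torch

def normalize_k_layout(k_rope: torch.Tensor, k_nope: torch.Tensor):
    """Insert a singleton head dimension for MLA's 2D K tensors so the C++
    binding only ever sees the 3D (nnz, num_kv_heads, head_dim) layout."""
    if k_rope.dim() == 2:
        # MLA passes K without a head dimension:
        # (nnz, rope_dim) -> (nnz, 1, rope_dim), likewise for the nope part.
        k_rope = k_rope.unsqueeze(1)
        k_nope = k_nope.unsqueeze(1)
    return k_rope, k_nope
```

With that normalization, the binding would no longer need separate 2D/3D shape checks; the MLA vs. GQA/MHA distinction could instead come from which cache tensors are passed.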