From 3f7b7edd57338c8734c16c445af8be495a20dbd9 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Mon, 3 Nov 2025 15:49:11 -0800 Subject: [PATCH 01/13] (RoPE + Q fp8 + append kv_cache) fused kernel for MLA/GQA/MHA --- csrc/flashinfer_rope_binding.cu | 13 ++ csrc/rope.cu | 200 +++++++++++++++++ flashinfer/rope.py | 365 ++++++++++++++++++++++++++++++++ include/flashinfer/pos_enc.cuh | 365 ++++++++++++++++++++++++++++++++ tests/attention/test_rope.py | 314 +++++++++++++++++++++++++++ 5 files changed, 1257 insertions(+) diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu index 23124064d8..e58deda5c0 100644 --- a/csrc/flashinfer_rope_binding.cu +++ b/csrc/flashinfer_rope_binding.cu @@ -45,9 +45,22 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave, bool enable_pdl); +// Fused RoPE + Quantize + Append Paged KV Cache (MLA/GQA/MHA) +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, + // Paged cache tensors + TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache, + TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions, + int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv, + bool interleave, bool enable_pdl); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache, + rope_quantize_append_paged_kv_cache); diff --git a/csrc/rope.cu b/csrc/rope.cu index 78cdcad405..4da3d49125 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -420,3 +420,203 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope }); }); } + +/*! + * TVM FFI binding for fused RoPE + quantization + paged KV cache append kernel + * + * Validates tensor shapes, dimensions, and data types, then dispatches to the templated + * RopeQuantizeAppendPagedKVCache CUDA kernel implementation. 
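+ *
+ * Architecture is inferred from the rank of the K tensors: 2-D k_rope_in/k_nope_in selects
+ * the MLA path (appends to ckv_cache/kpe_cache), while 3-D selects GQA/MHA (appends to
+ * k_cache/v_cache); the cache tensors of the unused path are expected to be passed as empty
+ * placeholders. Inputs must be float16/bfloat16 and the quantized outputs must be
+ * float8_e4m3fn or float8_e5m2.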
+ */ +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, + // Paged cache tensors + TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache, + TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions, + int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv, + bool interleave, bool enable_pdl) { + // Validate inputs + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out); + CHECK_INPUT(cos_sin_cache); + CHECK_INPUT(pos_ids); + CHECK_INPUT(kv_indices); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(batch_indices); + CHECK_INPUT(positions); + + // Extract dimensions + uint32_t rope_dim = q_rope_in.size(-1); + uint32_t no_rope_dim = q_nope_in.size(-1); + uint32_t nnz = q_rope_in.size(0); + uint32_t num_qo_heads = q_rope_in.size(1); + + // Validate dimensions + TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype()); + + // Validate input/output dtypes + TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16) + << "Input dtype must be float16 or bfloat16"; + TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2) + << "Output dtype must be float8_e4m3fn or float8_e5m2"; + + // Q tensors are always 3D + CHECK_DIM(3, q_rope_in); + CHECK_DIM(3, q_nope_in); + CHECK_DIM(3, q_rope_out); + CHECK_DIM(3, q_nope_out); + + // Detect architecture: MLA (2D K) vs GQA/MHA (3D K) + bool is_mla = (k_rope_in.ndim() == 2); + uint32_t num_kv_heads; + uint32_t batch_size = kv_indptr.size(0) - 1; + QKVLayout kv_layout = QKVLayout(kv_layout_code); + + if (is_mla) { + // MLA: K tensors are 2D + CHECK_DIM(2, k_rope_in); + CHECK_DIM(2, k_nope_in); + num_kv_heads = 1; + // V can be empty for MLA + TVM_FFI_ICHECK(v_in.data_ptr() == nullptr || v_in.size(0) == 0) + << "MLA should not have V input (or it should be empty)"; + // Validate MLA cache tensors are provided + TVM_FFI_ICHECK(ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr) + << "MLA requires ckv_cache and kpe_cache"; + CHECK_DIM(3, ckv_cache); // (max_pages, page_size, ckv_dim) + CHECK_DIM(3, kpe_cache); // (max_pages, page_size, kpe_dim) + TVM_FFI_ICHECK_EQ(ckv_cache.size(2), no_rope_dim); + TVM_FFI_ICHECK_EQ(kpe_cache.size(2), rope_dim); + } else { + // GQA/MHA: K tensors are 3D + CHECK_DIM(3, k_rope_in); + CHECK_DIM(3, k_nope_in); + num_kv_heads = k_rope_in.size(1); + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads); + // V is required for GQA/MHA + CHECK_DIM(3, v_in); + TVM_FFI_ICHECK_EQ(v_in.size(0), nnz); + TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads); + // Validate GQA/MHA cache tensors are provided + TVM_FFI_ICHECK(k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) + << "GQA/MHA requires k_cache and v_cache"; + // Cache must be 4D + 
CHECK_DIM(4, k_cache); + CHECK_DIM(4, v_cache); + } + + // Extract Q strides + const uint32_t q_rope_in_stride_n = q_rope_in.stride(0); + const uint32_t q_rope_in_stride_h = q_rope_in.stride(1); + const uint32_t q_nope_in_stride_n = q_nope_in.stride(0); + const uint32_t q_nope_in_stride_h = q_nope_in.stride(1); + const uint32_t q_rope_out_stride_n = q_rope_out.stride(0); + const uint32_t q_rope_out_stride_h = q_rope_out.stride(1); + const uint32_t q_nope_out_stride_n = q_nope_out.stride(0); + const uint32_t q_nope_out_stride_h = q_nope_out.stride(1); + + // Extract K strides (architecture dependent) + uint32_t k_rope_in_stride, k_nope_in_stride; + uint32_t k_rope_in_stride_h, k_nope_in_stride_h; + uint32_t v_in_stride = 0, v_in_stride_h = 0; + + if (is_mla) { + // MLA: 2D K tensors + k_rope_in_stride = k_rope_in.stride(0); + k_nope_in_stride = k_nope_in.stride(0); + k_rope_in_stride_h = k_rope_in_stride; // Same as batch stride for 2D + k_nope_in_stride_h = k_nope_in_stride; + } else { + // GQA/MHA: 3D K tensors + k_rope_in_stride = k_rope_in.stride(0); + k_rope_in_stride_h = k_rope_in.stride(1); + k_nope_in_stride = k_nope_in.stride(0); + k_nope_in_stride_h = k_nope_in.stride(1); + v_in_stride = v_in.stride(0); + v_in_stride_h = v_in.stride(1); + } + + cudaSetDevice(q_rope_in.device().device_id); + const cudaStream_t stream = get_stream(q_rope_in.device()); + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] { + cudaError_t status; + + if (is_mla) { + // MLA: Construct paged_kv_mla_t struct + auto ckv_strides = ckv_cache.strides(); + auto kpe_strides = kpe_cache.strides(); + + paged_kv_mla_t paged_kv_mla( + page_size, no_rope_dim, rope_dim, batch_size, + static_cast(ckv_cache.data_ptr()), ckv_strides.data(), + static_cast(kpe_cache.data_ptr()), kpe_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedMLACache( + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv_mla, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, + enable_pdl, stream); + + } else { + // GQA/MHA: Construct paged_kv_t struct + auto k_strides = k_cache.strides(); + auto v_strides = v_cache.strides(); + uint32_t head_dim = rope_dim + no_rope_dim; + + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + static_cast(k_cache.data_ptr()), + static_cast(v_cache.data_ptr()), k_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedKVCache( + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + static_cast(v_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + 
static_cast(q_nope_out.data_ptr()), paged_kv, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, + no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, + q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, + q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, + k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, + interleave, enable_pdl, stream); + } + + TVM_FFI_ICHECK(status == cudaSuccess) + << "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status); + return true; + }); + }); +} diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 7884c439be..72fa3d2ad9 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -226,6 +226,105 @@ def _fake_rope_quantize( pass +@register_custom_op( + "flashinfer::rope_quantize_append_paged_kv_cache", + mutates_args=( + "q_rope_out", + "q_nope_out", + "k_cache", + "v_cache", + "ckv_cache", + "kpe_cache", + ), +) +def _rope_quantize_fp8_append_paged_kv_cache( + q_rope_in: torch.Tensor, + k_rope_in: torch.Tensor, + q_nope_in: torch.Tensor, + k_nope_in: torch.Tensor, + v_in: torch.Tensor, + q_rope_out: torch.Tensor, + q_nope_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_layout_code: int, + page_size: int, + quant_scale_q: float, + quant_scale_kv: float, + interleave: bool, + enable_pdl: bool, +) -> None: + r"""Custom operator that routes to the CUDA kernel implementation. + + Fuses RoPE application, FP8 quantization, and paged KV cache append into a single kernel. + + Converts is_neox parameter to interleave format and dispatches to the underlying + CUDA kernel via the JIT-compiled module. 
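+
+    Note that the ``is_neox`` to ``interleave`` conversion happens in the public wrapper
+    ``rope_quantize_fp8_append_paged_kv_cache``; this op receives the already-converted
+    ``interleave`` flag and forwards it unchanged.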
+ """ + get_rope_module().rope_quantize_append_paged_kv_cache( + q_rope_in, + k_rope_in, + q_nope_in, + k_nope_in, + v_in, + q_rope_out, + q_nope_out, + cos_sin_cache, + pos_ids, + k_cache, + v_cache, + ckv_cache, + kpe_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + kv_layout_code, + page_size, + quant_scale_q, + quant_scale_kv, + interleave, + enable_pdl, + ) + + +@register_fake_op("flashinfer::rope_quantize_fp8_append_paged_kv_cache") +def _fake_rope_quantize_fp8_append_paged_kv_cache( + q_rope_in: torch.Tensor, + k_rope_in: torch.Tensor, + q_nope_in: torch.Tensor, + k_nope_in: torch.Tensor, + v_in: torch.Tensor, + q_rope_out: torch.Tensor, + q_nope_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_layout_code: int, + page_size: int, + quant_scale_q: float, + quant_scale_kv: float, + interleave: bool, + enable_pdl: bool, +) -> None: + pass + + @register_custom_op( "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope") ) @@ -1303,3 +1402,269 @@ def rope_quantize_fp8( ) return q_rope_out, k_rope_out, q_nope_out, k_nope_out + + +def rope_quantize_fp8_append_paged_kv_cache( + q_rope: torch.Tensor, + k_rope: torch.Tensor, + q_nope: torch.Tensor, + k_nope: torch.Tensor, + v: Optional[torch.Tensor], + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + paged_kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + is_neox: bool = True, + quantize_dtype: Optional[torch.dtype] = None, + quant_scale_q: float = 1.0, + quant_scale_kv: float = 1.0, + page_size: int = 16, + kv_layout: str = "NHD", + q_rope_out: Optional[torch.Tensor] = None, + q_nope_out: Optional[torch.Tensor] = None, + enable_pdl: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Apply RoPE (Rotary Positional Embeddings), quantize to FP8, and append K/V to paged cache. + + This fused function applies RoPE to query/key (Q/K) rotary dimension tensors, quantizes all Q/K tensors + (and V for GQA/MHA) to FP8 format, and directly appends the quantized K/V to a paged KV cache. + It returns quantized Q tensors for use in attention computation. Supports MLA, GQA, and MHA + architectures with automatic detection based on input tensor shapes. + + Parameters + ---------- + q_rope : torch.Tensor + Query tensor (rotary dimensions), shape: ``(nnz, num_qo_heads, rope_dim)``. + Must be float16 or bfloat16. + k_rope : torch.Tensor + Key tensor (rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, rope_dim)``. + For MLA: ``(nnz, rope_dim)``. Must be float16 or bfloat16. + q_nope : torch.Tensor + Query tensor (non-rotary dimensions), shape: ``(nnz, num_qo_heads, no_rope_dim)``. + Must be float16 or bfloat16. + k_nope : torch.Tensor + Key tensor (non-rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, no_rope_dim)``. + For MLA: ``(nnz, no_rope_dim)``. Must be float16 or bfloat16. + v : Optional[torch.Tensor] + Value tensor for GQA/MHA: ``(nnz, num_kv_heads, head_dim)``. Must be float16 or bfloat16. + For MLA: pass ``None`` (MLA does not use separate V; K non-RoPE acts as compressed KV). + cos_sin_cache : torch.Tensor + Precomputed cosine and sine values, shape: ``(max_seq_len, rope_dim)``. 
+ First half contains cosine values, second half contains sine values. Must be float32. + pos_ids : torch.Tensor + Position indices for each token, shape: ``(nnz,)``. + paged_kv_cache : Tuple[torch.Tensor, torch.Tensor] + For MLA: ``(ckv_cache, kpe_cache)`` where: + - ckv_cache: ``(max_pages, page_size, no_rope_dim)`` in FP8 + - kpe_cache: ``(max_pages, page_size, rope_dim)`` in FP8 + For GQA/MHA: ``(k_cache, v_cache)`` where: + - k_cache: ``(max_pages, page_size, num_kv_heads, head_dim)`` or + ``(max_pages, num_kv_heads, page_size, head_dim)`` depending on layout, in FP8 + - v_cache: same shape as k_cache, in FP8 + kv_indices : torch.Tensor + Page indices mapping, shape: ``(total_pages,)``. Typically ``torch.arange(total_pages)``. + kv_indptr : torch.Tensor + Page indptr array for each request, shape: ``(batch_size + 1,)``. + ``kv_indptr[i]`` is the starting page index for request ``i``. + batch_indices : torch.Tensor + Batch index for each token, shape: ``(nnz,)``. Maps each token to its request. + positions : torch.Tensor + Position within each request's sequence for each token, shape: ``(nnz,)``. + is_neox : bool + RoPE layout style. If ``True`` (default), use non-interleaved layout (first/second half). + If ``False``, use interleaved layout (even/odd dimensions). + quantize_dtype : Optional[torch.dtype] + Target quantization dtype. If ``None``, inferred from output tensors or defaults to + ``torch.float8_e4m3fn``. Must be ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. + quant_scale_q : float + Quantization scaling factor for query tensors, default: ``1.0``. + quant_scale_kv : float + Quantization scaling factor for key/value tensors, default: ``1.0``. + page_size : int + Number of entries per page in the paged cache, default: ``16``. + kv_layout : str + Cache memory layout for GQA/MHA. Options: ``"NHD"`` (page, seq, head, dim) or + ``"HND"`` (page, head, seq, dim). Default: ``"NHD"``. Ignored for MLA. + q_rope_out : Optional[torch.Tensor] + Pre-allocated output tensor for quantized query (rotary). If ``None``, allocated automatically. + q_nope_out : Optional[torch.Tensor] + Pre-allocated output tensor for quantized query (non-rotary). If ``None``, allocated automatically. + enable_pdl : bool + Whether to enable PDL (Programmatic Dependent Launch). Default: ``False``. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Quantized query tensors: (q_rope_out, q_nope_out). + K/V are written directly to the paged cache and not returned. 
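+
+    The ``batch_indices`` and ``positions`` arguments are typically derived from the paged
+    metadata. A minimal sketch, assuming ``kv_page_indptr``, ``kv_last_page_len`` and
+    ``kv_append_indptr`` have already been constructed for the batch:
+
+    >>> seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size)
+    >>> batch_indices, positions = flashinfer.get_batch_indices_positions(
+    ...     kv_append_indptr, seq_lens, num_tokens
+    ... )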
+ + Examples + -------- + MLA example: + + >>> import torch + >>> import flashinfer + >>> # MLA setup: 2D K tensors + >>> num_tokens, num_qo_heads, rope_dim, no_rope_dim = 32, 128, 64, 512 + >>> q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda") + >>> q_nope = torch.randn(num_tokens, num_qo_heads, no_rope_dim, dtype=torch.float16, device="cuda") + >>> k_rope = torch.randn(num_tokens, rope_dim, dtype=torch.float16, device="cuda") + >>> k_nope = torch.randn(num_tokens, no_rope_dim, dtype=torch.float16, device="cuda") + >>> # Allocate MLA paged cache + >>> max_pages, page_size = 10, 16 + >>> ckv_cache = torch.zeros(max_pages, page_size, no_rope_dim, dtype=torch.float8_e4m3fn, device="cuda") + >>> kpe_cache = torch.zeros(max_pages, page_size, rope_dim, dtype=torch.float8_e4m3fn, device="cuda") + >>> # Setup RoPE and metadata + >>> rope_emb = flashinfer.rope.FlashInferRotaryEmbedding(rope_dim + no_rope_dim, rope_dim, 4096, 10000, False, torch.float16, "cuda") + >>> pos_ids = torch.arange(num_tokens, device="cuda", dtype=torch.int32) + >>> kv_page_indices = torch.arange(max_pages, device="cuda", dtype=torch.int32) + >>> kv_page_indptr = torch.tensor([0, max_pages], device="cuda", dtype=torch.int32) + >>> batch_indices = torch.zeros(num_tokens, device="cuda", dtype=torch.int32) + >>> positions = torch.arange(num_tokens, device="cuda", dtype=torch.int32) + >>> # Fused call + >>> q_rope_out, q_nope_out = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + ... q_rope, k_rope, q_nope, k_nope, None, + ... rope_emb.cos_sin_cache, pos_ids, + ... (ckv_cache, kpe_cache), + ... kv_page_indices, kv_page_indptr, batch_indices, positions, + ... is_neox=False, quantize_dtype=torch.float8_e4m3fn, + ... page_size=page_size + ... ) + + GQA example: + + >>> # GQA setup: 3D K/V tensors + >>> num_tokens, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim = 32, 32, 8, 64, 64 + >>> head_dim = rope_dim + no_rope_dim + >>> q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda") + >>> q_nope = torch.randn(num_tokens, num_qo_heads, no_rope_dim, dtype=torch.float16, device="cuda") + >>> k_rope = torch.randn(num_tokens, num_kv_heads, rope_dim, dtype=torch.float16, device="cuda") + >>> k_nope = torch.randn(num_tokens, num_kv_heads, no_rope_dim, dtype=torch.float16, device="cuda") + >>> v = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") + >>> # Allocate GQA paged cache (NHD layout) + >>> max_pages, page_size = 10, 16 + >>> k_cache = torch.zeros(max_pages, page_size, num_kv_heads, head_dim, dtype=torch.float8_e4m3fn, device="cuda") + >>> v_cache = torch.zeros(max_pages, page_size, num_kv_heads, head_dim, dtype=torch.float8_e4m3fn, device="cuda") + >>> # Fused call + >>> q_rope_out, q_nope_out = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + ... q_rope, k_rope, q_nope, k_nope, v, + ... rope_emb.cos_sin_cache, pos_ids, + ... (k_cache, v_cache), + ... kv_page_indices, kv_page_indptr, batch_indices, positions, + ... is_neox=False, quantize_dtype=torch.float8_e4m3fn, + ... page_size=page_size, kv_layout="NHD" + ... ) + + Notes + ----- + - Architecture detection: Automatically distinguishes MLA (2D K tensors) from GQA/MHA (3D K tensors). + - MLA writes K-RoPE to ``kpe_cache`` and K-noRoPE to ``ckv_cache``; V is not used. + - GQA/MHA writes full K (RoPE+noRoPE) to ``k_cache`` and V to ``v_cache``. 
+ - The ``batch_indices`` and ``positions`` tensors are typically obtained from + ``flashinfer.get_batch_indices_positions()``. + - Cache tensors must already be allocated in the target FP8 dtype. + """ + if cos_sin_cache.dtype != torch.float32: + raise ValueError("cos_sin_cache should be float32") + + # Detect architecture + is_mla = k_rope.ndim == 2 + + # Infer quantize_dtype from output tensors or default + if quantize_dtype is None: + if q_rope_out is not None: + quantize_dtype = q_rope_out.dtype + elif q_nope_out is not None: + quantize_dtype = q_nope_out.dtype + else: + quantize_dtype = torch.float8_e4m3fn + + # Allocate Q output tensors if not provided + if q_rope_out is None: + q_rope_out = torch.empty_like(q_rope, dtype=quantize_dtype) + if q_nope_out is None: + q_nope_out = torch.empty_like(q_nope, dtype=quantize_dtype) + + # Handle V input for MLA (create empty dummy tensor, not used) + if is_mla: + if v is None: + v = torch.empty(0, dtype=q_rope.dtype, device=q_rope.device) + else: + raise ValueError("MLA should not have V input (pass None)") + + # Unpack and validate cache tensors + if len(paged_kv_cache) != 2: + raise ValueError("paged_kv_cache must be a tuple of 2 tensors") + + cache_0, cache_1 = paged_kv_cache + + if is_mla: + # MLA: Expect (ckv_cache, kpe_cache) + ckv_cache = cache_0 + kpe_cache = cache_1 + if ckv_cache.dtype != quantize_dtype or kpe_cache.dtype != quantize_dtype: + raise ValueError( + f"MLA cache dtype mismatch: expected {quantize_dtype}, " + f"got ckv={ckv_cache.dtype}, kpe={kpe_cache.dtype}" + ) + if ckv_cache.ndim != 3 or kpe_cache.ndim != 3: + raise ValueError( + f"MLA cache must be 3D: (max_pages, page_size, dim), " + f"got ckv={ckv_cache.ndim}D, kpe={kpe_cache.ndim}D" + ) + # Create dummy tensors for GQA/MHA cache (not used) + k_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + v_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + else: + # GQA/MHA: Expect (k_cache, v_cache) + k_cache = cache_0 + v_cache = cache_1 + if k_cache.dtype != quantize_dtype or v_cache.dtype != quantize_dtype: + raise ValueError( + f"GQA/MHA cache dtype mismatch: expected {quantize_dtype}, " + f"got k={k_cache.dtype}, v={v_cache.dtype}" + ) + if k_cache.ndim not in [4, 5] or v_cache.ndim not in [4, 5]: + raise ValueError( + f"GQA/MHA cache must be 4D or 5D, got k={k_cache.ndim}D, v={v_cache.ndim}D" + ) + # Create dummy tensors for MLA cache (not used) + ckv_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + kpe_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + + # Import TensorLayout enum + from .utils import TensorLayout + + kv_layout_code = TensorLayout[kv_layout].value + + # Call custom op + _rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + q_rope_out, + q_nope_out, + cos_sin_cache, + pos_ids, + k_cache, + v_cache, + ckv_cache, + kpe_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + kv_layout_code, + page_size, + quant_scale_q, + quant_scale_kv, + not is_neox, # interleave + enable_pdl, + ) + + return q_rope_out, q_nope_out diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 7547a06090..478cb39d02 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -23,6 +23,7 @@ #include "layout.cuh" #include "math.cuh" +#include "page.cuh" #include "utils.cuh" #include "vec_dtypes.cuh" @@ -746,6 +747,217 @@ __global__ void BatchQKApplyRotaryKernel( FLASHINFER_ERROR("Unsupported rope_dim. 
Supported values: 16, 32, 64, 128, 256"); \ } +/*! + * \brief Unified CUDA kernel to apply RoPE, quantize to FP8, and append to paged cache. + * + * Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t). + * Cache-only behaviors are selected with constexpr on the CacheT. + */ +template +__global__ void RopeQuantizeAppendPagedKVCacheKernel( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, + QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like, + IdType* __restrict__ batch_indices, IdType* __restrict__ positions, + float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, + size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, + size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, + size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, + size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, + size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv) { +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + uint32_t by = blockIdx.y; + uint32_t bdy = blockDim.y; + + // Calculate flexible boundaries for block allocation + uint32_t rope_chunk_size = rope_dim; + uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size; + uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size; + + uint32_t q_rope_end = num_qo_heads * rope_chunks; + // For MLA, num_kv_heads is effectively 1 + uint32_t k_rope_end = q_rope_end + num_kv_heads * rope_chunks; + uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + + // Compute page location for this token + uint32_t page_iter, entry_idx; + paged_kv_like.page_size.divmod( + paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx], + page_iter, entry_idx); + + const int half_rope_dim = rope_dim / 2; + // Load cos/sin for RoPE processing blocks only + if ((tx * vec_size < rope_dim) and (by < k_rope_end)) { + int sin_offset = rope_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rope_dim; + } + cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx)); + } + + if (by < q_rope_end) { + // ============ Q RoPE processing ============ + uint32_t q_head_idx = by / rope_chunks; + uint32_t rope_chunk_idx = by % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* q_rope_in_ptr = + q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n, + q_rope_in_stride_h); + QuantType* q_rope_out_ptr = + q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n, + q_rope_out_stride_h); + + vec_t q_rope_vec; + if constexpr (interleave) { + q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + q_rope_in_ptr, cos, sin, rope_dim); + } else { + q_rope_vec = vec_apply_llama_rope_cos_sin(q_rope_in_ptr, cos, sin, rope_dim); + } +#pragma unroll + for (uint32_t i = 0; i < 
vec_size; ++i) { + q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; + } + q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); + + } else if (by < k_rope_end) { + // ============ K RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - q_rope_end) / rope_chunks; + uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* k_rope_in_ptr; + if constexpr (IS_MLA) { + // MLA: 2D K + k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset; + } else { + // GQA/MHA: 3D K + k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_rope_in_stride, k_rope_in_stride_h); + } + + vec_t k_rope_vec; + if constexpr (interleave) { + k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + k_rope_in_ptr, cos, sin, rope_dim); + } else { + k_rope_vec = vec_apply_llama_rope_cos_sin(k_rope_in_ptr, cos, sin, rope_dim); + } +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; + } + + if constexpr (IS_MLA) { + QuantType* kpe_ptr = + paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_rope_vec.cast_store(kpe_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size); + k_rope_vec.cast_store(k_ptr); + } + + } else if (by < k_nope_end) { + // ============ K Non-RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* k_nope_in_ptr; + if constexpr (IS_MLA) { + k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset; + } else { + k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_nope_in_stride, k_nope_in_stride_h); + } + + vec_t k_nope_vec; + k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; + } + + if constexpr (IS_MLA) { + QuantType* ckv_ptr = + paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_nope_vec.cast_store(ckv_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, + rope_dim + elem_offset + tx * vec_size); + k_nope_vec.cast_store(k_ptr); + } + + } else if (by < k_nope_end + (IS_MLA ? 
0u : num_kv_heads)) { + // ============ V processing & Cache Append (GQA/MHA only) ============ + if constexpr (!IS_MLA) { + uint32_t kv_head_idx = by - k_nope_end; + DType* v_in_ptr = + v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h); + // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size + uint32_t head_dim_total = rope_dim + no_rope_dim; + uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size; +#pragma unroll 1 + for (uint32_t j = 0; j < v_chunks; ++j) { + uint32_t v_elem_offset = j * rope_chunk_size; + if (v_elem_offset + tx * vec_size < head_dim_total) { + vec_t v_vec; + v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + v_vec[i] = v_vec[i] * quant_scale_kv; + } + QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, + v_elem_offset + tx * vec_size); + v_vec.cast_store(v_ptr); + } + } + } + + } else { + // ============ Q Non-RoPE processing ============ + // MLA has no V section, so Q-nope starts immediately after K-nope. + // GQA/MHA has a V section of length num_kv_heads blocks. + uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads); + uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* q_nope_in_ptr = + q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n, + q_nope_in_stride_h); + QuantType* q_nope_out_ptr = + q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n, + q_nope_out_stride_h); + + vec_t q_nope_vec; + q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; + } + q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); + } + } +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + template cudaError_t RopeQuantize( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out, @@ -838,6 +1050,159 @@ cudaError_t RopeQuantize( return cudaSuccess; } +/*! 
+ * \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA) + */ +template +cudaError_t RopeQuantizeAppendPagedKVCache( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, + QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, + IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, + size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, + size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, + size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, + size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, + size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv, + bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + constexpr uint32_t vec_size = 32 / sizeof(DType); + + DISPATCH_ROPE_DIM(rope_dim, vec_size, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + uint32_t num_threads = 128U; + uint32_t bdy = num_threads / bdx; + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + // GQA/MHA: Q rope + K rope + K nope + V + Q nope + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_kv_heads + + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = + RopeQuantizeAppendPagedKVCacheKernel>; + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( + &config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // sizes + nnz, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, + // Q strides (in/out) + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + // K strides + k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, k_nope_in_stride_h, + // V strides + v_in_stride, v_in_stride_h, + // scales + quant_scale_q, quant_scale_kv)); + }); + }); + + return cudaSuccess; +} + +/*! 
+ * \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache + */ +template +cudaError_t RopeQuantizeAppendPagedMLACache( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out, + QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, IdType* batch_indices, + IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, + uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, + size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, + size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv, + bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + constexpr uint32_t vec_size = 32 / sizeof(DType); + + DISPATCH_ROPE_DIM(rope_dim, vec_size, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + uint32_t num_threads = 128U; + uint32_t bdy = num_threads / bdx; + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + // MLA: Q rope + K rope + K nope + Q nope (no V) + constexpr uint32_t num_kv_heads = 1; + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = + RopeQuantizeAppendPagedKVCacheKernel>; + // For MLA: pass v_in as nullptr, num_kv_heads=1, duplicate 2D K strides for head strides, and + // 0 V strides + DType* v_in_nullptr = nullptr; + uint32_t num_kv_heads_1 = 1; + size_t k_rope_in_stride_h_dup = k_rope_in_stride; + size_t k_nope_in_stride_h_dup = k_nope_in_stride; + size_t v_in_stride_zero = 0, v_in_stride_h_zero = 0; + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx( + &config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in_nullptr, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv_mla, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // sizes + nnz, num_qo_heads, num_kv_heads_1, rope_dim, no_rope_dim, + // Q strides (in/out) + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + // K strides (2D: duplicate for head stride) + k_rope_in_stride, k_rope_in_stride_h_dup, k_nope_in_stride, k_nope_in_stride_h_dup, + // V strides (unused for MLA) + v_in_stride_zero, v_in_stride_h_zero, + // scales + quant_scale_q, quant_scale_kv)); + }); + }); + + return cudaSuccess; +} + template cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_sin_cache, IdType* pos_ids, diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index da59223a4f..4a8f443775 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -394,6 +394,10 @@ def test_generalized_rope_quantize( ): """Test generalized rope + 
quantization for MLA, GQA, and MHA architectures.""" device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) total_dim = rope_dim + no_rope_dim # Create input tensors based on attention type @@ -481,6 +485,312 @@ def test_generalized_rope_quantize( ) +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + # MLA: Multiple Q heads, single shared K/V head + ("mla", 128, 1, 64, 512), + ("mla", 64, 1, 128, 256), + ("mla", 128, 1, 64, 128), # Explicit DeepSeek R1 MLA config case + ("mla", 32, 1, 32, 96), + # GQA: Multiple Q heads, fewer K/V heads (grouped) + ("gqa", 32, 8, 64, 64), + ("gqa", 64, 16, 128, 128), + ("gqa", 24, 6, 32, 96), + ("gqa", 32, 8, 128, 0), # Llama3 8B standard config + ("gqa", 64, 8, 128, 0), # Llama3 70B standard config + ("gqa", 64, 8, 64, 0), # (plausible) GPT-OSS config + # MHA: Equal Q and K/V heads + ("mha", 32, 32, 64, 64), + ("mha", 16, 16, 128, 128), + ("mha", 8, 8, 32, 96), + ], +) +@pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) +@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +def test_generalized_rope_quantize_append_kv_cache( + attention_type, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + num_tokens, + input_dtype, + quant_dtype, + enable_pdl, + kv_layout, +): + device = "cuda:0" + # Fixed seed for reproducibility + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + head_dim = rope_dim + no_rope_dim + page_size = 16 + batch_size = 4 + + # Build inputs following the same pattern used elsewhere + if attention_type == "mla": + # Q: (N, Hq, *), K: 2D (N, *)x + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + v = None + else: + # GQA/MHA: K/V are 3D + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Cos/sin and positions + max_seq_len = 4096 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata + kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, 
dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + # Allocate caches sized by required pages + max_pages = kv_page_indptr[-1].item() + + # Get batch_indices and positions + seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Fused call + cache allocation + if attention_type == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + q_rope_out_fused, q_nope_out_fused = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + None, + rope_ref.cos_sin_cache, + pos_ids, + (ckv_cache, kpe_cache), + kv_page_indices, + kv_page_indptr, + batch_indices, + positions, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + else: + # Allocate cache based on layout + if kv_layout == "NHD": + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + q_rope_out_fused, q_nope_out_fused = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + rope_ref.cos_sin_cache, + pos_ids, + (k_cache, v_cache), + kv_page_indices, + kv_page_indptr, + batch_indices, + positions, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + # Compute reference output + q_in = torch.cat([q_rope, q_nope], dim=-1) + k_in = torch.cat([k_rope, k_nope], dim=-1) + q_out_f16_ref, k_out_f16_ref = rope_ref.forward_native(pos_ids, q_in, k_in) + q_out_f8_ref, k_out_f8_ref = map( + lambda x: x.to(quant_dtype), + (q_out_f16_ref, k_out_f16_ref), + ) + + # Fused vs Pytorch reference Q checks + torch.testing.assert_close( + q_out_f8_ref[..., :rope_dim].float(), + q_rope_out_fused.float(), + rtol=2e-1, + atol=1e-2, + ) + torch.testing.assert_close( + q_out_f8_ref[..., rope_dim:].float(), + q_nope_out_fused.float(), + rtol=2e-1, + atol=1e-2, + ) + + # expect 1-ULP differences between FP8 device rounding and PyTorch .to(fp8) + if quant_dtype == torch.float8_e4m3fn: + rtol_val, atol_val = 0.25, 0.5 + else: # quant_dtype == torch.float8_e5m2: + rtol_val, atol_val = 0.25, 1.0 + + # if MLA: check ckv_cache, kpe_cache + if attention_type == "mla": + # Split K reference + k_rope_ref = k_out_f8_ref[..., :rope_dim] + k_nope_ref = k_out_f8_ref[..., rope_dim:] + + ckv_ref = torch.zeros_like(ckv_cache) + kpe_ref = torch.zeros_like(kpe_cache) + + for i in range(num_tokens): + b = batch_indices[i].item() + pos = 
positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + ckv_ref[page_idx, entry_idx, :] = k_nope_ref[i] + kpe_ref[page_idx, entry_idx, :] = k_rope_ref[i] + + torch.testing.assert_close( + ckv_cache.float(), ckv_ref.float(), rtol=rtol_val, atol=atol_val + ) + torch.testing.assert_close( + kpe_cache.float(), kpe_ref.float(), rtol=rtol_val, atol=atol_val + ) + + # if GQA/MHA: check k_cache, v_cache + if attention_type == "gqa" or attention_type == "mha": + # K reference + k_ref = torch.zeros_like(k_cache) + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + if kv_layout == "NHD": + k_ref[page_idx, entry_idx, :, :] = k_out_f8_ref[i] # [Hkv, head_dim] + else: # HND + k_ref[page_idx, :, entry_idx, :] = k_out_f8_ref[i] # [Hkv, head_dim] + + torch.testing.assert_close( + k_cache.float(), k_ref.float(), rtol=rtol_val, atol=atol_val + ) + + # V reference (no RoPE on V; same quant scale as KV) + quant_scale_kv = 1.0 # match fused call + v_ref_tokens = (v * quant_scale_kv).to(quant_dtype) + v_ref = torch.zeros_like(v_cache) + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + if kv_layout == "NHD": + v_ref[page_idx, entry_idx, :, :] = v_ref_tokens[i] + else: # HND + v_ref[page_idx, :, entry_idx, :] = v_ref_tokens[i] + + torch.testing.assert_close( + v_cache.float(), v_ref.float(), rtol=rtol_val, atol=atol_val + ) + + @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @@ -492,6 +802,10 @@ def test_mla_rope_quantize( enable_pdl, ): device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) num_qo_heads = 128 q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device) k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device) From 4040e9c86159741d3c335037b3529b9a8ddb96fa Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 4 Nov 2025 10:59:44 -0800 Subject: [PATCH 02/13] add a decode test --- csrc/flashinfer_rope_binding.cu | 11 +- tests/attention/test_rope.py | 511 ++++++++++++++++++++++++++++++++ 2 files changed, 515 insertions(+), 7 deletions(-) diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu index e58deda5c0..94809da735 100644 --- a/csrc/flashinfer_rope_binding.cu +++ b/csrc/flashinfer_rope_binding.cu @@ -45,16 +45,13 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave, bool enable_pdl); -// Fused RoPE + Quantize + Append Paged KV Cache (MLA/GQA/MHA) void rope_quantize_append_paged_kv_cache( TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, 
- TensorView pos_ids, - // Paged cache tensors - TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache, - TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions, - int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv, - bool interleave, bool enable_pdl); + TensorView pos_ids, TensorView k_cache, TensorView v_cache, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, + TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q, + double quant_scale_kv, bool interleave, bool enable_pdl); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope); diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 4a8f443775..917adba3ba 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -791,6 +791,517 @@ def test_generalized_rope_quantize_append_kv_cache( ) +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + # MLA: Multiple Q heads, single shared K/V head + ("mla", 128, 1, 64, 512), + ("mla", 32, 1, 32, 96), + # GQA: Multiple Q heads, fewer K/V heads (grouped) + ("gqa", 32, 8, 64, 64), + ("gqa", 32, 8, 128, 0), # Llama3 8B standard config + # MHA: Equal Q and K/V heads + ("mha", 32, 32, 64, 64), + ("mha", 16, 16, 128, 128), + ], +) +@pytest.mark.parametrize("num_existing_tokens", [10, 50]) +@pytest.mark.parametrize("num_new_tokens", [1, 8]) +@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +def test_rope_quantize_fp8_append_paged_kv_cache_decode( + attention_type, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + num_existing_tokens, + num_new_tokens, + input_dtype, + quant_dtype, + enable_pdl, + kv_layout, +): + """Test append to non-empty cache (decode/continuation scenario).""" + device = "cuda:0" + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + head_dim = rope_dim + no_rope_dim + page_size = 16 + batch_size = 2 + + # Step 1: Pre-populate cache with existing tokens + if attention_type == "mla": + q_rope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + k_rope_existing = torch.randn( + num_existing_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope_existing = torch.randn( + num_existing_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + v_existing = None + else: + q_rope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + k_rope_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + k_nope_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + v_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + head_dim, + dtype=input_dtype, + 
device=device, + ) + + # Create RoPE reference + max_seq_len = 4096 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + pos_ids_existing = torch.arange( + num_existing_tokens, device=device, dtype=torch.int32 + ) + + # Build metadata for existing tokens (single request for simplicity) + kv_append_length_existing = torch.tensor( + [num_existing_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr_existing = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length_existing, dim=0), + ] + ) + num_pages_existing = (num_existing_tokens + page_size - 1) // page_size + kv_page_indptr_existing = torch.tensor( + [0, num_pages_existing] + [num_pages_existing] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indices_existing = torch.arange( + num_pages_existing, dtype=torch.int32, device=device + ) + kv_last_page_len_existing = torch.tensor( + [ + num_existing_tokens % page_size + if num_existing_tokens % page_size != 0 + else page_size + ] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + seq_lens_existing = flashinfer.get_seq_lens( + kv_page_indptr_existing, kv_last_page_len_existing, page_size + ) + batch_indices_existing, positions_existing = flashinfer.get_batch_indices_positions( + kv_append_indptr_existing, seq_lens_existing, num_existing_tokens + ) + + # Allocate cache sized for existing + new tokens + total_tokens = num_existing_tokens + num_new_tokens + max_pages = (total_tokens + page_size - 1) // page_size + + if attention_type == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + # Pre-populate with existing tokens + _, _ = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_existing, + k_rope_existing, + q_nope_existing, + k_nope_existing, + None, + rope_ref.cos_sin_cache, + pos_ids_existing, + (ckv_cache, kpe_cache), + kv_page_indices_existing, + kv_page_indptr_existing, + batch_indices_existing, + positions_existing, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + else: + if kv_layout == "NHD": + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + # Pre-populate with existing tokens + _, _ = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_existing, + k_rope_existing, + q_nope_existing, + k_nope_existing, + v_existing, + rope_ref.cos_sin_cache, + pos_ids_existing, + (k_cache, v_cache), + kv_page_indices_existing, + kv_page_indptr_existing, + batch_indices_existing, + positions_existing, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + + # Step 2: Append new tokens to the pre-populated cache + if attention_type == "mla": + q_rope_new = torch.randn( + num_new_tokens, 
num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope_new = torch.randn( + num_new_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope_new = torch.randn( + num_new_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope_new = torch.randn( + num_new_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + v_new = None + else: + q_rope_new = torch.randn( + num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope_new = torch.randn( + num_new_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope_new = torch.randn( + num_new_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope_new = torch.randn( + num_new_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + v_new = torch.randn( + num_new_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + pos_ids_new = torch.arange( + num_existing_tokens, + num_existing_tokens + num_new_tokens, + device=device, + dtype=torch.int32, + ) + + # Build metadata for new tokens (continue appending to first request) + num_pages_new_needed = (total_tokens + page_size - 1) // page_size + kv_page_indptr_new = torch.tensor( + [0, num_pages_new_needed] + [num_pages_new_needed] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indices_new = torch.arange( + num_pages_new_needed, dtype=torch.int32, device=device + ) + # For continuation, positions start at num_existing_tokens + batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=torch.int32) + positions_new = torch.arange( + num_existing_tokens, + num_existing_tokens + num_new_tokens, + device=device, + dtype=torch.int32, + ) + + # Snapshot existing cache for later comparison + if attention_type == "mla": + ckv_cache_before = ckv_cache.clone() + kpe_cache_before = kpe_cache.clone() + else: + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + + # Append new tokens + if attention_type == "mla": + q_rope_out_new, q_nope_out_new = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_new, + k_rope_new, + q_nope_new, + k_nope_new, + None, + rope_ref.cos_sin_cache, + pos_ids_new, + (ckv_cache, kpe_cache), + kv_page_indices_new, + kv_page_indptr_new, + batch_indices_new, + positions_new, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + else: + q_rope_out_new, q_nope_out_new = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_new, + k_rope_new, + q_nope_new, + k_nope_new, + v_new, + rope_ref.cos_sin_cache, + pos_ids_new, + (k_cache, v_cache), + kv_page_indices_new, + kv_page_indptr_new, + batch_indices_new, + positions_new, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + + # Verify Q outputs for new tokens + q_in_new = torch.cat([q_rope_new, q_nope_new], dim=-1) + k_in_new = torch.cat([k_rope_new, k_nope_new], dim=-1) + q_out_f16_ref_new, k_out_f16_ref_new = rope_ref.forward_native( + pos_ids_new, q_in_new, k_in_new + ) + q_out_f8_ref_new = q_out_f16_ref_new.to(quant_dtype) + k_out_f8_ref_new = k_out_f16_ref_new.to(quant_dtype) + + torch.testing.assert_close( + q_out_f8_ref_new[..., :rope_dim].float(), + q_rope_out_new.float(), + rtol=2e-1, + atol=1e-2, + ) + torch.testing.assert_close( + q_out_f8_ref_new[..., rope_dim:].float(), + 
q_nope_out_new.float(), + rtol=2e-1, + atol=1e-2, + ) + + # FP8 tolerances + if quant_dtype == torch.float8_e4m3fn: + rtol_val, atol_val = 0.25, 0.5 + else: + rtol_val, atol_val = 0.25, 1.0 + + # Verify existing cache entries remain unchanged + if attention_type == "mla": + # Check that entries before num_existing_tokens are unchanged + for i in range(num_existing_tokens): + b = batch_indices_existing[i].item() + pos = positions_existing[i].item() + page_iter = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) // page_size + entry_idx = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) % page_size + page_idx = kv_page_indices_existing[page_iter].item() + torch.testing.assert_close( + ckv_cache[page_idx, entry_idx, :].float(), + ckv_cache_before[page_idx, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing CKV cache entry {i} was modified", + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_idx, :].float(), + kpe_cache_before[page_idx, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing KPE cache entry {i} was modified", + ) + else: + for i in range(num_existing_tokens): + b = batch_indices_existing[i].item() + pos = positions_existing[i].item() + page_iter = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) // page_size + entry_idx = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) % page_size + page_idx = kv_page_indices_existing[page_iter].item() + if kv_layout == "NHD": + torch.testing.assert_close( + k_cache[page_idx, entry_idx, :, :].float(), + k_cache_before[page_idx, entry_idx, :, :].float(), + rtol=0, + atol=0, + msg=f"Existing K cache entry {i} was modified", + ) + torch.testing.assert_close( + v_cache[page_idx, entry_idx, :, :].float(), + v_cache_before[page_idx, entry_idx, :, :].float(), + rtol=0, + atol=0, + msg=f"Existing V cache entry {i} was modified", + ) + else: # HND + torch.testing.assert_close( + k_cache[page_idx, :, entry_idx, :].float(), + k_cache_before[page_idx, :, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing K cache entry {i} was modified", + ) + torch.testing.assert_close( + v_cache[page_idx, :, entry_idx, :].float(), + v_cache_before[page_idx, :, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing V cache entry {i} was modified", + ) + + # Verify new cache entries are correct + if attention_type == "mla": + k_rope_ref_new = k_out_f8_ref_new[..., :rope_dim] + k_nope_ref_new = k_out_f8_ref_new[..., rope_dim:] + + for i in range(num_new_tokens): + b = batch_indices_new[i].item() + pos = positions_new[i].item() + page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices_new[page_iter].item() + torch.testing.assert_close( + ckv_cache[page_idx, entry_idx, :].float(), + k_nope_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_idx, :].float(), + k_rope_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + else: + quant_scale_kv = 1.0 + v_ref_tokens_new = (v_new * quant_scale_kv).to(quant_dtype) + + for i in range(num_new_tokens): + b = batch_indices_new[i].item() + pos = positions_new[i].item() + page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices_new[page_iter].item() + if kv_layout == "NHD": + torch.testing.assert_close( + k_cache[page_idx, entry_idx, :, :].float(), + 
k_out_f8_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + v_cache[page_idx, entry_idx, :, :].float(), + v_ref_tokens_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + else: # HND + torch.testing.assert_close( + k_cache[page_idx, :, entry_idx, :].float(), + k_out_f8_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + v_cache[page_idx, :, entry_idx, :].float(), + v_ref_tokens_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + + @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) From 20e123b069d27c9dfcd762e50601b6b456fa2caa Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 4 Nov 2025 12:00:18 -0800 Subject: [PATCH 03/13] align fake op registration with the custom op name --- flashinfer/rope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 72fa3d2ad9..02dcfe502c 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -296,7 +296,7 @@ def _rope_quantize_fp8_append_paged_kv_cache( ) -@register_fake_op("flashinfer::rope_quantize_fp8_append_paged_kv_cache") +@register_fake_op("flashinfer::rope_quantize_append_paged_kv_cache") def _fake_rope_quantize_fp8_append_paged_kv_cache( q_rope_in: torch.Tensor, k_rope_in: torch.Tensor, From 58a96f23c3cde4829c73d3100a31262c540a2f74 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 4 Nov 2025 12:13:53 -0800 Subject: [PATCH 04/13] only 4D k_cache/v_cache --- flashinfer/rope.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 02dcfe502c..c58599a982 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1627,9 +1627,9 @@ def rope_quantize_fp8_append_paged_kv_cache( f"GQA/MHA cache dtype mismatch: expected {quantize_dtype}, " f"got k={k_cache.dtype}, v={v_cache.dtype}" ) - if k_cache.ndim not in [4, 5] or v_cache.ndim not in [4, 5]: + if k_cache.ndim != 4 or v_cache.ndim != 4: raise ValueError( - f"GQA/MHA cache must be 4D or 5D, got k={k_cache.ndim}D, v={v_cache.ndim}D" + f"GQA/MHA cache must be 4D, got k={k_cache.ndim}D, v={v_cache.ndim}D" ) # Create dummy tensors for MLA cache (not used) ckv_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) From 034be6a978bf920f7d9c013a1ea266c4c74b1275 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 4 Nov 2025 12:25:39 -0800 Subject: [PATCH 05/13] add check: GQA/MHA expects a V tensor,but got None. also get rid of docstring examples; too long and too complicated. Maybe just reference test code instead as an example? --- flashinfer/rope.py | 63 +++++----------------------------------------- 1 file changed, 6 insertions(+), 57 deletions(-) diff --git a/flashinfer/rope.py b/flashinfer/rope.py index c58599a982..67bd04bef6 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1501,63 +1501,6 @@ def rope_quantize_fp8_append_paged_kv_cache( Quantized query tensors: (q_rope_out, q_nope_out). K/V are written directly to the paged cache and not returned. 
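A condensed MLA sketch of the call pattern exercised by the test suite (the shapes, single-request page-table setup, and the FlashInferRotaryEmbedding helper mirror the tests; the specific sizes are illustrative, not canonical API documentation):

import torch
import flashinfer

num_tokens, num_qo_heads, rope_dim, no_rope_dim = 32, 128, 64, 512
page_size, max_pages = 16, 2
q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda")
q_nope = torch.randn(num_tokens, num_qo_heads, no_rope_dim, dtype=torch.float16, device="cuda")
k_rope = torch.randn(num_tokens, rope_dim, dtype=torch.float16, device="cuda")  # MLA: 2D K
k_nope = torch.randn(num_tokens, no_rope_dim, dtype=torch.float16, device="cuda")
ckv_cache = torch.zeros(max_pages, page_size, no_rope_dim, dtype=torch.float8_e4m3fn, device="cuda")
kpe_cache = torch.zeros(max_pages, page_size, rope_dim, dtype=torch.float8_e4m3fn, device="cuda")
rope_emb = flashinfer.rope.FlashInferRotaryEmbedding(
    rope_dim + no_rope_dim, rope_dim, 4096, 10000, False, torch.float16, "cuda"
)
pos_ids = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
kv_page_indices = torch.arange(max_pages, device="cuda", dtype=torch.int32)
kv_page_indptr = torch.tensor([0, max_pages], device="cuda", dtype=torch.int32)
batch_indices = torch.zeros(num_tokens, device="cuda", dtype=torch.int32)  # single request
positions = torch.arange(num_tokens, device="cuda", dtype=torch.int32)
q_rope_out, q_nope_out = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
    q_rope, k_rope, q_nope, k_nope, None,  # v is None for MLA
    rope_emb.cos_sin_cache, pos_ids, (ckv_cache, kpe_cache),
    kv_page_indices, kv_page_indptr, batch_indices, positions,
    page_size=page_size, quantize_dtype=torch.float8_e4m3fn, is_neox=False,
)
# For GQA/MHA, pass 3-D k_rope/k_nope plus v and a (k_cache, v_cache) tuple
# (optionally kv_layout="NHD"/"HND") in place of the MLA (ckv_cache, kpe_cache) pair.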
- Examples - -------- - MLA example: - - >>> import torch - >>> import flashinfer - >>> # MLA setup: 2D K tensors - >>> num_tokens, num_qo_heads, rope_dim, no_rope_dim = 32, 128, 64, 512 - >>> q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda") - >>> q_nope = torch.randn(num_tokens, num_qo_heads, no_rope_dim, dtype=torch.float16, device="cuda") - >>> k_rope = torch.randn(num_tokens, rope_dim, dtype=torch.float16, device="cuda") - >>> k_nope = torch.randn(num_tokens, no_rope_dim, dtype=torch.float16, device="cuda") - >>> # Allocate MLA paged cache - >>> max_pages, page_size = 10, 16 - >>> ckv_cache = torch.zeros(max_pages, page_size, no_rope_dim, dtype=torch.float8_e4m3fn, device="cuda") - >>> kpe_cache = torch.zeros(max_pages, page_size, rope_dim, dtype=torch.float8_e4m3fn, device="cuda") - >>> # Setup RoPE and metadata - >>> rope_emb = flashinfer.rope.FlashInferRotaryEmbedding(rope_dim + no_rope_dim, rope_dim, 4096, 10000, False, torch.float16, "cuda") - >>> pos_ids = torch.arange(num_tokens, device="cuda", dtype=torch.int32) - >>> kv_page_indices = torch.arange(max_pages, device="cuda", dtype=torch.int32) - >>> kv_page_indptr = torch.tensor([0, max_pages], device="cuda", dtype=torch.int32) - >>> batch_indices = torch.zeros(num_tokens, device="cuda", dtype=torch.int32) - >>> positions = torch.arange(num_tokens, device="cuda", dtype=torch.int32) - >>> # Fused call - >>> q_rope_out, q_nope_out = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( - ... q_rope, k_rope, q_nope, k_nope, None, - ... rope_emb.cos_sin_cache, pos_ids, - ... (ckv_cache, kpe_cache), - ... kv_page_indices, kv_page_indptr, batch_indices, positions, - ... is_neox=False, quantize_dtype=torch.float8_e4m3fn, - ... page_size=page_size - ... ) - - GQA example: - - >>> # GQA setup: 3D K/V tensors - >>> num_tokens, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim = 32, 32, 8, 64, 64 - >>> head_dim = rope_dim + no_rope_dim - >>> q_rope = torch.randn(num_tokens, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda") - >>> q_nope = torch.randn(num_tokens, num_qo_heads, no_rope_dim, dtype=torch.float16, device="cuda") - >>> k_rope = torch.randn(num_tokens, num_kv_heads, rope_dim, dtype=torch.float16, device="cuda") - >>> k_nope = torch.randn(num_tokens, num_kv_heads, no_rope_dim, dtype=torch.float16, device="cuda") - >>> v = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda") - >>> # Allocate GQA paged cache (NHD layout) - >>> max_pages, page_size = 10, 16 - >>> k_cache = torch.zeros(max_pages, page_size, num_kv_heads, head_dim, dtype=torch.float8_e4m3fn, device="cuda") - >>> v_cache = torch.zeros(max_pages, page_size, num_kv_heads, head_dim, dtype=torch.float8_e4m3fn, device="cuda") - >>> # Fused call - >>> q_rope_out, q_nope_out = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( - ... q_rope, k_rope, q_nope, k_nope, v, - ... rope_emb.cos_sin_cache, pos_ids, - ... (k_cache, v_cache), - ... kv_page_indices, kv_page_indptr, batch_indices, positions, - ... is_neox=False, quantize_dtype=torch.float8_e4m3fn, - ... page_size=page_size, kv_layout="NHD" - ... ) - Notes ----- - Architecture detection: Automatically distinguishes MLA (2D K tensors) from GQA/MHA (3D K tensors). @@ -1622,6 +1565,12 @@ def rope_quantize_fp8_append_paged_kv_cache( # GQA/MHA: Expect (k_cache, v_cache) k_cache = cache_0 v_cache = cache_1 + # Validate V input is provided for GQA/MHA + if v is None: + raise ValueError( + "GQA/MHA expects a V tensor, but got None. 
" + "Only MLA uses None for V (compressed KV representation)." + ) if k_cache.dtype != quantize_dtype or v_cache.dtype != quantize_dtype: raise ValueError( f"GQA/MHA cache dtype mismatch: expected {quantize_dtype}, " From 368dc2befae3d64c4a165432f98e98d1c58c6f06 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Wed, 12 Nov 2025 12:00:20 -0800 Subject: [PATCH 06/13] unsqueeze mla into dim 3 to match mha/gqa --- csrc/rope.cu | 39 +++++++++++++++------------------- flashinfer/rope.py | 7 +++++- include/flashinfer/pos_enc.cuh | 4 ++-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/csrc/rope.cu b/csrc/rope.cu index 4da3d49125..40388d9412 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -477,17 +477,21 @@ void rope_quantize_append_paged_kv_cache( CHECK_DIM(3, q_rope_out); CHECK_DIM(3, q_nope_out); - // Detect architecture: MLA (2D K) vs GQA/MHA (3D K) - bool is_mla = (k_rope_in.ndim() == 2); + // Detect architecture based on cache presence/layout (not K dimensionality) + QKVLayout kv_layout = QKVLayout(kv_layout_code); + bool has_mla_caches = (ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr); + bool has_gqa_caches = (k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr); + bool is_mla = has_mla_caches && !has_gqa_caches; uint32_t num_kv_heads; uint32_t batch_size = kv_indptr.size(0) - 1; - QKVLayout kv_layout = QKVLayout(kv_layout_code); + // Require 3D K tensors in both paths; for MLA head dim must be 1 + CHECK_DIM(3, k_rope_in); + CHECK_DIM(3, k_nope_in); if (is_mla) { - // MLA: K tensors are 2D - CHECK_DIM(2, k_rope_in); - CHECK_DIM(2, k_nope_in); num_kv_heads = 1; + TVM_FFI_ICHECK_EQ(k_rope_in.size(1), 1) << "MLA expects K rope head dim == 1"; + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), 1) << "MLA expects K nope head dim == 1"; // V can be empty for MLA TVM_FFI_ICHECK(v_in.data_ptr() == nullptr || v_in.size(0) == 0) << "MLA should not have V input (or it should be empty)"; @@ -499,9 +503,7 @@ void rope_quantize_append_paged_kv_cache( TVM_FFI_ICHECK_EQ(ckv_cache.size(2), no_rope_dim); TVM_FFI_ICHECK_EQ(kpe_cache.size(2), rope_dim); } else { - // GQA/MHA: K tensors are 3D - CHECK_DIM(3, k_rope_in); - CHECK_DIM(3, k_nope_in); + // GQA/MHA validation num_kv_heads = k_rope_in.size(1); TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads); // V is required for GQA/MHA @@ -526,23 +528,16 @@ void rope_quantize_append_paged_kv_cache( const uint32_t q_nope_out_stride_n = q_nope_out.stride(0); const uint32_t q_nope_out_stride_h = q_nope_out.stride(1); - // Extract K strides (architecture dependent) + // Extract K strides uint32_t k_rope_in_stride, k_nope_in_stride; uint32_t k_rope_in_stride_h, k_nope_in_stride_h; uint32_t v_in_stride = 0, v_in_stride_h = 0; - if (is_mla) { - // MLA: 2D K tensors - k_rope_in_stride = k_rope_in.stride(0); - k_nope_in_stride = k_nope_in.stride(0); - k_rope_in_stride_h = k_rope_in_stride; // Same as batch stride for 2D - k_nope_in_stride_h = k_nope_in_stride; - } else { - // GQA/MHA: 3D K tensors - k_rope_in_stride = k_rope_in.stride(0); - k_rope_in_stride_h = k_rope_in.stride(1); - k_nope_in_stride = k_nope_in.stride(0); - k_nope_in_stride_h = k_nope_in.stride(1); + k_rope_in_stride = k_rope_in.stride(0); + k_nope_in_stride = k_nope_in.stride(0); + k_rope_in_stride_h = k_rope_in.stride(1); + k_nope_in_stride_h = k_nope_in.stride(1); + if (!is_mla) { v_in_stride = v_in.stride(0); v_in_stride_h = v_in.stride(1); } diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 67bd04bef6..7480327681 100644 --- a/flashinfer/rope.py +++ 
b/flashinfer/rope.py @@ -1531,8 +1531,13 @@ def rope_quantize_fp8_append_paged_kv_cache( if q_nope_out is None: q_nope_out = torch.empty_like(q_nope, dtype=quantize_dtype) - # Handle V input for MLA (create empty dummy tensor, not used) + # Handle MLA normalization and V (create empty dummy tensor, not used) if is_mla: + # Normalize MLA K tensors to 3D (nnz, 1, dim) so C++ binding can always assume 3D + if k_rope.ndim == 2: + k_rope = k_rope.unsqueeze(1) + if k_nope.ndim == 2: + k_nope = k_nope.unsqueeze(1) if v is None: v = torch.empty(0, dtype=q_rope.dtype, device=q_rope.device) else: diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 478cb39d02..e0b23dfc17 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -385,7 +385,7 @@ __global__ void RopeQuantizeKernel( // 2. if not interleave // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)] - if ((tx * vec_size < rope_dim) and (by < k_rope_end)) { + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { int sin_offset = rope_dim / 2; int vec_idx; if constexpr (interleave) { @@ -796,7 +796,7 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( const int half_rope_dim = rope_dim / 2; // Load cos/sin for RoPE processing blocks only - if ((tx * vec_size < rope_dim) and (by < k_rope_end)) { + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { int sin_offset = rope_dim / 2; int vec_idx; if constexpr (interleave) { From 2556467d98400a860d0a1538fb1f1d20cc7201a4 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Wed, 12 Nov 2025 12:10:24 -0800 Subject: [PATCH 07/13] paramaterize page size 16,32 for testing --- tests/attention/test_rope.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 917adba3ba..3b422b6c15 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -511,6 +511,7 @@ def test_generalized_rope_quantize( @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16, 32]) def test_generalized_rope_quantize_append_kv_cache( attention_type, num_qo_heads, @@ -522,6 +523,7 @@ def test_generalized_rope_quantize_append_kv_cache( quant_dtype, enable_pdl, kv_layout, + page_size, ): device = "cuda:0" # Fixed seed for reproducibility @@ -530,7 +532,6 @@ def test_generalized_rope_quantize_append_kv_cache( torch.cuda.manual_seed_all(0) head_dim = rope_dim + no_rope_dim - page_size = 16 batch_size = 4 # Build inputs following the same pattern used elsewhere @@ -811,6 +812,7 @@ def test_generalized_rope_quantize_append_kv_cache( @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("enable_pdl", [True, False]) @pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16, 32]) def test_rope_quantize_fp8_append_paged_kv_cache_decode( attention_type, num_qo_heads, @@ -823,6 +825,7 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( quant_dtype, enable_pdl, kv_layout, + page_size, ): """Test append to non-empty cache (decode/continuation scenario).""" device = "cuda:0" @@ -831,7 +834,6 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( torch.cuda.manual_seed_all(42) head_dim = rope_dim + no_rope_dim - page_size = 16 
batch_size = 2 # Step 1: Pre-populate cache with existing tokens From 401bab3dc363c21606f9c2594023851a7a54fa56 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Wed, 12 Nov 2025 16:16:18 -0800 Subject: [PATCH 08/13] when no_rope_dim, optional None for nope tensors --- flashinfer/rope.py | 49 +++++++++-- tests/attention/test_rope.py | 152 +++++++++++++++++++++++++---------- 2 files changed, 151 insertions(+), 50 deletions(-) diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 7480327681..dea6995bcf 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -1285,8 +1285,8 @@ def mla_rope_quantize_fp8( def rope_quantize_fp8( q_rope: torch.Tensor, k_rope: torch.Tensor, - q_nope: torch.Tensor, - k_nope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], cos_sin_cache: torch.Tensor, pos_ids: torch.Tensor, is_neox: bool = True, @@ -1313,12 +1313,12 @@ def rope_quantize_fp8( k_rope : torch.Tensor Key tensor (rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, rope_dim)``. For MLA: ``(nnz, rope_dim)``. Must be float16 or bfloat16. - q_nope : torch.Tensor + q_nope : Optional[torch.Tensor] Query tensor (non-rotary dimensions), shape: ``(nnz, num_qo_heads, no_rope_dim)``. - Must be float16 or bfloat16. - k_nope : torch.Tensor + If ``None``, treated as zero-dim: a size-0 tensor will be created internally. + k_nope : Optional[torch.Tensor] Key tensor (non-rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, no_rope_dim)``. - For MLA: ``(nnz, no_rope_dim)``. Must be float16 or bfloat16. + For MLA: ``(nnz, no_rope_dim)``. If ``None``, treated as zero-dim and created internally. cos_sin_cache : torch.Tensor Precomputed cosine and sine values, shape: ``(max_seq_len, rope_dim)``. First half contains cosine values, second half contains sine values. Must be float32. 
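For the common case where the entire head is rotary (no_rope_dim == 0, e.g. the Llama3-style GQA config in the tests), the nope tensors can now simply be passed as None. A minimal sketch, assuming a standard base-10000 cos/sin cache laid out as documented above (cos in the first half, sin in the second); the exact return structure of rope_quantize_fp8 is an assumption here, see its definition in flashinfer/rope.py:

import torch
import flashinfer

nnz, num_qo_heads, num_kv_heads, rope_dim = 8, 32, 8, 128
q_rope = torch.randn(nnz, num_qo_heads, rope_dim, dtype=torch.float16, device="cuda")
k_rope = torch.randn(nnz, num_kv_heads, rope_dim, dtype=torch.float16, device="cuda")
pos_ids = torch.arange(nnz, device="cuda", dtype=torch.int32)

# cos_sin_cache: (max_seq_len, rope_dim), float32, cosine half followed by sine half
inv_freq = 1.0 / (
    10000.0 ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device="cuda") / rope_dim)
)
t = torch.arange(4096, dtype=torch.float32, device="cuda")
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

outputs = flashinfer.rope.rope_quantize_fp8(
    q_rope, k_rope, None, None,  # q_nope / k_nope omitted: no_rope_dim == 0
    cos_sin_cache, pos_ids,
    is_neox=False, quantize_dtype=torch.float8_e4m3fn,
)
# `outputs` is assumed to carry the quantized FP8 Q/K results; the empty nope
# tensors are materialized internally as size-0 tensors, as shown in the hunk below.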
@@ -1353,6 +1353,23 @@ def rope_quantize_fp8( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") + # Allow None for nope tensors and normalize to size-0 tensors with correct shapes + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + is_mla = k_rope.ndim == 2 + num_kv_heads = 1 if is_mla else k_rope.shape[1] + if q_nope is None: + q_nope = torch.empty( + nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device + ) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + k_nope = torch.empty( + nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device + ) + # Infer quantize_dtype from output tensors or default to float8_e4m3fn if quantize_dtype is None: for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out): @@ -1407,8 +1424,8 @@ def rope_quantize_fp8( def rope_quantize_fp8_append_paged_kv_cache( q_rope: torch.Tensor, k_rope: torch.Tensor, - q_nope: torch.Tensor, - k_nope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], v: Optional[torch.Tensor], cos_sin_cache: torch.Tensor, pos_ids: torch.Tensor, @@ -1516,6 +1533,22 @@ def rope_quantize_fp8_append_paged_kv_cache( # Detect architecture is_mla = k_rope.ndim == 2 + # Allow None for nope tensors and normalize to size-0 tensors with correct shapes + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + if q_nope is None: + q_nope = torch.empty( + nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device + ) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + num_kv_heads = k_rope.shape[1] + k_nope = torch.empty( + nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device + ) + # Infer quantize_dtype from output tensors or default if quantize_dtype is None: if q_rope_out is not None: diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index 3b422b6c15..8e694088e5 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -536,29 +536,45 @@ def test_generalized_rope_quantize_append_kv_cache( # Build inputs following the same pattern used elsewhere if attention_type == "mla": - # Q: (N, Hq, *), K: 2D (N, *)x + # Q: (N, Hq, *), K: 2D (N, *) q_rope = torch.randn( num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device ) - q_nope = torch.randn( - num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + q_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) ) k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) - k_nope = torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + k_nope = ( + None + if no_rope_dim == 0 + else torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + ) v = None else: # GQA/MHA: K/V are 3D q_rope = torch.randn( num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device ) - q_nope = torch.randn( - num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + q_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) ) k_rope = torch.randn( num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device ) - k_nope = torch.randn( - num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + k_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, 
num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) ) v = torch.randn( num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device @@ -699,9 +715,9 @@ def test_generalized_rope_quantize_append_kv_cache( enable_pdl=enable_pdl, ) ) - # Compute reference output - q_in = torch.cat([q_rope, q_nope], dim=-1) - k_in = torch.cat([k_rope, k_nope], dim=-1) + # Compute reference output (handle None for no_rope_dim == 0) + q_in = q_rope if q_nope is None else torch.cat([q_rope, q_nope], dim=-1) + k_in = k_rope if k_nope is None else torch.cat([k_rope, k_nope], dim=-1) q_out_f16_ref, k_out_f16_ref = rope_ref.forward_native(pos_ids, q_in, k_in) q_out_f8_ref, k_out_f8_ref = map( lambda x: x.to(quant_dtype), @@ -845,18 +861,26 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( dtype=input_dtype, device=device, ) - q_nope_existing = torch.randn( - num_existing_tokens, - num_qo_heads, - no_rope_dim, - dtype=input_dtype, - device=device, + q_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) ) k_rope_existing = torch.randn( num_existing_tokens, rope_dim, dtype=input_dtype, device=device ) - k_nope_existing = torch.randn( - num_existing_tokens, no_rope_dim, dtype=input_dtype, device=device + k_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, no_rope_dim, dtype=input_dtype, device=device + ) ) v_existing = None else: @@ -867,12 +891,16 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( dtype=input_dtype, device=device, ) - q_nope_existing = torch.randn( - num_existing_tokens, - num_qo_heads, - no_rope_dim, - dtype=input_dtype, - device=device, + q_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) ) k_rope_existing = torch.randn( num_existing_tokens, @@ -881,12 +909,16 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( dtype=input_dtype, device=device, ) - k_nope_existing = torch.randn( - num_existing_tokens, - num_kv_heads, - no_rope_dim, - dtype=input_dtype, - device=device, + k_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) ) v_existing = torch.randn( num_existing_tokens, @@ -1036,28 +1068,56 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( q_rope_new = torch.randn( num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device ) - q_nope_new = torch.randn( - num_new_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + q_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) ) k_rope_new = torch.randn( num_new_tokens, rope_dim, dtype=input_dtype, device=device ) - k_nope_new = torch.randn( - num_new_tokens, no_rope_dim, dtype=input_dtype, device=device + k_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, no_rope_dim, dtype=input_dtype, device=device + ) ) v_new = None else: q_rope_new = torch.randn( num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device ) - q_nope_new = torch.randn( - num_new_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + q_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + 
device=device, + ) ) k_rope_new = torch.randn( num_new_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device ) - k_nope_new = torch.randn( - num_new_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + k_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) ) v_new = torch.randn( num_new_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device @@ -1146,9 +1206,17 @@ def test_rope_quantize_fp8_append_paged_kv_cache_decode( ) ) - # Verify Q outputs for new tokens - q_in_new = torch.cat([q_rope_new, q_nope_new], dim=-1) - k_in_new = torch.cat([k_rope_new, k_nope_new], dim=-1) + # Verify Q outputs for new tokens (handle None for no_rope_dim == 0) + q_in_new = ( + q_rope_new + if q_nope_new is None + else torch.cat([q_rope_new, q_nope_new], dim=-1) + ) + k_in_new = ( + k_rope_new + if k_nope_new is None + else torch.cat([k_rope_new, k_nope_new], dim=-1) + ) q_out_f16_ref_new, k_out_f16_ref_new = rope_ref.forward_native( pos_ids_new, q_in_new, k_in_new ) From 49b70ec1b180124edf003d0f0045440ceb20f3a5 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Wed, 12 Nov 2025 20:06:29 -0800 Subject: [PATCH 09/13] benchmarkign script revise use ncu --- .../bench_rope_quantize_fp8_append_cache.py | 479 ++++++++++++++++++ 1 file changed, 479 insertions(+) create mode 100644 benchmarks/bench_rope_quantize_fp8_append_cache.py diff --git a/benchmarks/bench_rope_quantize_fp8_append_cache.py b/benchmarks/bench_rope_quantize_fp8_append_cache.py new file mode 100644 index 0000000000..8d1356d9cd --- /dev/null +++ b/benchmarks/bench_rope_quantize_fp8_append_cache.py @@ -0,0 +1,479 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import sys +import csv +import subprocess +import argparse +import flashinfer +import numpy as np +import torch +from flashinfer.testing.utils import bench_gpu_time_with_cudagraph + +# Add the project root to Python path to import test helpers +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from tests.test_helpers.rope_reference import RotaryEmbedding + + +def benchmark_config( + config_name, + num_tokens, + batch_size=4, + page_size=16, + enable_pdl=False, + single_run=False, +): + """Benchmark a specific attention configuration with paged KV cache append.""" + input_dtype = torch.bfloat16 + device = "cuda" + quant_dtype = torch.float8_e4m3fn + + # Configuration-specific parameters + if config_name == "mla": + # MLA: DeepSeek-style multi-latent attention + num_qo_heads, num_kv_heads = 128, 1 + rope_dim, no_rope_dim = 64, 512 + elif config_name == "gqa": + # GQA: Grouped-query attention (e.g., Llama-style) + num_qo_heads, num_kv_heads = 32, 8 + rope_dim, no_rope_dim = 64, 64 + elif config_name == "mha": + # MHA: Standard multi-head attention + num_qo_heads, num_kv_heads = 32, 32 + rope_dim, no_rope_dim = 64, 64 + else: + raise ValueError(f"Unknown config: {config_name}") + + head_dim = rope_dim + no_rope_dim + + # Create input tensors + if config_name == "mla": + # MLA: 2D K tensors (shared) + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + v = None + else: + # GQA/MHA: 3D K/V tensors + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Create RoPE reference for cos/sin cache (ensure it covers this run) + max_seq_len = int(num_tokens) + rope_ref = RotaryEmbedding( + head_size=head_dim, + rotary_dim=rope_dim, + max_position_embeddings=max_seq_len, + base=10000, + is_neox_style=False, + dtype=input_dtype, + device=device, + ) + pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata (single request with all tokens) + kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + + # Get batch_indices and positions + seq_lens = 
flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Allocate caches + max_pages = kv_page_indptr[-1].item() + + if config_name == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + paged_kv_cache = (ckv_cache, kpe_cache) + else: + # GQA/MHA: use NHD layout + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + paged_kv_cache = (k_cache, v_cache) + + run_idx = 0 + + def execute(): + if single_run: + import torch.cuda.nvtx as nvtx + + nvtx.range_push("rope_append") + nonlocal run_idx + run_idx += 1 + + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + v=v, + cos_sin_cache=rope_ref.cos_sin_cache, + pos_ids=pos_ids, + paged_kv_cache=paged_kv_cache, + kv_indices=kv_page_indices, + kv_indptr=kv_page_indptr, + batch_indices=batch_indices, + positions=positions, + page_size=page_size, + kv_layout="NHD" if config_name != "mla" else "NHD", + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + if single_run: + # Ensure kernels complete inside the NVTX range for ncu filtering + torch.cuda.synchronize() + nvtx.range_pop() + + if single_run: + execute() + return None, None, None, None, None + measurements = bench_gpu_time_with_cudagraph(execute) + + # Calculate I/O bytes + # Inputs: q_rope, k_rope, q_nope, k_nope, v (if not MLA), cos_sin_cache, pos_ids + io_bytes = ( + q_rope.numel() * q_rope.element_size() + + k_rope.numel() * k_rope.element_size() + + q_nope.numel() * q_nope.element_size() + + k_nope.numel() * k_nope.element_size() + + rope_ref.cos_sin_cache.numel() * rope_ref.cos_sin_cache.element_size() + + pos_ids.numel() * pos_ids.element_size() + ) + + if v is not None: + io_bytes += v.numel() * v.element_size() + + # Outputs: q_rope_out, q_nope_out (FP8), cache writes (FP8) + io_bytes += ( + q_rope.numel() * torch.finfo(quant_dtype).bits // 8 + + q_nope.numel() * torch.finfo(quant_dtype).bits // 8 + ) + + if config_name == "mla": + # MLA writes to ckv_cache and kpe_cache + io_bytes += ( + num_tokens * no_rope_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * rope_dim * torch.finfo(quant_dtype).bits // 8 + ) + else: + # GQA/MHA writes to k_cache and v_cache + io_bytes += ( + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + ) + + # Calculate statistics + ms = np.median(measurements) + min_ms = np.percentile(measurements, 20) + max_ms = np.percentile(measurements, 80) + + # Calculate bandwidth in GB/s + bandwidth_gb_s = io_bytes / ms / 1e6 + + # Calculate TFLOPs (FP operations) + # RoPE: 6 FLOPs per dimension pair (2 muls + 1 sub for real, 2 muls + 1 add for imag) + # For Q: num_tokens * num_qo_heads * (rope_dim/2) pairs * 6 FLOPs + # For K: depends on architecture + q_flops = num_tokens * num_qo_heads * (rope_dim / 2) * 6 + + if config_name == "mla": + # MLA: K is 2D (no head dimension) + k_flops = num_tokens * (rope_dim / 2) * 6 + else: + # GQA/MHA: K is 3D (has head 
dimension) + k_flops = num_tokens * num_kv_heads * (rope_dim / 2) * 6 + + total_flops = q_flops + k_flops + tflops = ( + total_flops / ms / 1e9 + ) # TFLOPs (operations per ms = operations per second / 1e12) + + return ms, min_ms, max_ms, bandwidth_gb_s, tflops + + +def _run_ncu_and_get_bw_pct( + script_path, config_name, num_tokens, page_size, enable_pdl +): + cmd = [ + "ncu", + "--target-processes", + "all", + "--metrics", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed,gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed", + "--csv", + "--page", + "raw", + sys.executable, + script_path, + "--ncu-single", + "--config", + config_name, + "--num-tokens", + str(num_tokens), + "--page-size", + str(page_size), + "--enable-pdl", + str(int(enable_pdl)), + ] + try: + out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, text=True) + except (subprocess.CalledProcessError, FileNotFoundError) as e: + print(f"Warning: Nsight Compute not available or failed: {e}") + return -1.0 + # parse ncu output csv + target_metric_dram = "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed" + target_metric_compute = ( + "gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed" + ) + lines = [ + l + for l in out.splitlines() + if l.strip() and not l.startswith("#") and not l.startswith("==") + ] + reader = csv.reader(lines) + header = None + kernel_idx = None + dram_idx = None + compute_idx = None + rows = [] + for row in reader: + if not row: + continue + if header is None: + header = [c.strip() for c in row] + # Locate columns + try: + kernel_idx = header.index("Kernel Name") + except ValueError: + for i, c in enumerate(header): + if c.replace(" ", "").lower() == "kernelname": + kernel_idx = i + break + try: + dram_idx = header.index(target_metric_dram) + except ValueError: + dram_idx = None + try: + compute_idx = header.index(target_metric_compute) + except ValueError: + compute_idx = None + continue + rows.append(row) + if header is None or (dram_idx is None and compute_idx is None): + print("Warning: Unable to parse BW% from Nsight Compute output.") + return -1.0, -1.0 + # Only return the fused kernel's metrics; otherwise fail + fused_dram = None + fused_compute = None + for r in rows: + if len(r) <= max((kernel_idx or 0), dram_idx or 0, compute_idx or 0): + continue + kname = r[kernel_idx] if kernel_idx is not None else "" + kname_l = kname.lower() + # Match our fused kernel robustly + if ( + "ropequantizeappendpagedkvcachekernel".lower() in kname_l + or ("rope" in kname_l and "append" in kname_l) + or "ropequantizeappend" in kname_l + ): + try: + if dram_idx is not None and len(r) > dram_idx: + fused_dram = float(r[dram_idx]) + except Exception: + fused_dram = None + try: + if compute_idx is not None and len(r) > compute_idx: + fused_compute = float(r[compute_idx]) + except Exception: + fused_compute = None + break + if fused_dram is not None or fused_compute is not None: + return ( + fused_dram if fused_dram is not None else -1.0, + fused_compute if fused_compute is not None else -1.0, + ) + print("Warning: Unable to find fused kernel metric in Nsight Compute output.") + return -1.0, -1.0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ncu-single", action="store_true", help="Run a single execute() for ncu" + ) + parser.add_argument( + "--config", type=str, default="", help="Config name: mla/gqa/mha" + ) + parser.add_argument("--num-tokens", type=int, default=0) + parser.add_argument("--page-size", type=int, default=16) 
+ parser.add_argument("--enable-pdl", type=int, default=0) + args, unknown = parser.parse_known_args() + + if args.ncu_single: + # Minimal single-run for ncu profiling + cfg = args.config or "mla" + ntok = int(args.num_tokens) + pgsz = int(args.page_size) + en_pdl = bool(int(args.enable_pdl)) + # Force a single execution path + benchmark_config(cfg, ntok, page_size=pgsz, enable_pdl=en_pdl, single_run=True) + sys.exit(0) + + # Get GPU information (for display only) + gpu_name = torch.cuda.get_device_name(0) + print(f"\nDetected GPU: {gpu_name}") + print() + + # Token counts to benchmark + token_counts = [1, 32, 128, 384, 768, 1024, 2048, 4096, 8192] + + # Helper function to print a table for a specific configuration + def print_config_table(config_name, config_desc): + page_size_to_benchmark = 32 + print(f"\n{'=' * 100}") + print(f" {config_name.upper()}: {config_desc}") + print(f"{'=' * 100}") + + # Only use ncu if available. if not available, skip BW% calculations + first_tokens = token_counts[-1] + try: + probe_dram_pct, probe_compute_pct = _run_ncu_and_get_bw_pct( + os.path.abspath(__file__), + config_name, + first_tokens, + page_size_to_benchmark, + False, + ) + show_bw_pct = (probe_dram_pct >= 0) and (probe_compute_pct >= 0) + except Exception as e: + print( + f"Warning: Skipping BW% calculations due to Nsight Compute not available or failed: {e}" + ) + show_bw_pct = False + + if show_bw_pct: + print( + f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'BW% (DRAM)':<14} {'BW% (Compute)':<16} {'TFLOPs':<12}" + ) + print("-" * 80) + for num_tokens in token_counts: + ms, _, _, bw, tflops = benchmark_config( + config_name, num_tokens, page_size=page_size_to_benchmark + ) + dram_pct, compute_pct = _run_ncu_and_get_bw_pct( + os.path.abspath(__file__), + config_name, + num_tokens, + page_size_to_benchmark, + False, + ) + if dram_pct < 0 or compute_pct < 0: + # If individual row fails, fall back to skipping BW% entirely for consistency + show_bw_pct = False + break + print( + f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {dram_pct:<14.1f} {compute_pct:<16.1f} {tflops:<12.3f}" + ) + if not show_bw_pct: + print(f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'TFLOPs':<12}") + print("-" * 50) + for num_tokens in token_counts: + ms, _, _, bw, tflops = benchmark_config( + config_name, num_tokens, page_size=page_size_to_benchmark + ) + print(f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {tflops:<12.3f}") + + # Print tables for each configuration + print_config_table("mla", "128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)") + print_config_table("gqa", "32 Q heads, 8 K heads, 64+64 dims (Llama-style)") + print_config_table("mha", "32 Q heads, 32 K heads, 64+64 dims (Standard)") + + print("\n" + "=" * 100) + print("Configuration details:") + print(" Page size: 32, Batch size: 4") + print(" Token range: 1 (single decode) → 8192 (large prefill)") + print(f" GPU: {gpu_name}") + print("=" * 100) From 3a9755458d034e9d9493f712e144804ae2ddaacf Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 13 Nov 2025 05:51:30 +0000 Subject: [PATCH 10/13] upd --- .../bench_rope_quantize_fp8_append_cache.py | 167 ++---------------- flashinfer/utils.py | 41 +++++ 2 files changed, 56 insertions(+), 152 deletions(-) diff --git a/benchmarks/bench_rope_quantize_fp8_append_cache.py b/benchmarks/bench_rope_quantize_fp8_append_cache.py index 8d1356d9cd..3119b9fef8 100644 --- a/benchmarks/bench_rope_quantize_fp8_append_cache.py +++ b/benchmarks/bench_rope_quantize_fp8_append_cache.py @@ -16,13 +16,12 @@ import os import 
sys -import csv -import subprocess import argparse import flashinfer import numpy as np import torch from flashinfer.testing.utils import bench_gpu_time_with_cudagraph +from flashinfer.utils import get_gpu_memory_bandwidth # Add the project root to Python path to import test helpers sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -274,111 +273,6 @@ def execute(): return ms, min_ms, max_ms, bandwidth_gb_s, tflops -def _run_ncu_and_get_bw_pct( - script_path, config_name, num_tokens, page_size, enable_pdl -): - cmd = [ - "ncu", - "--target-processes", - "all", - "--metrics", - "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed,gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed", - "--csv", - "--page", - "raw", - sys.executable, - script_path, - "--ncu-single", - "--config", - config_name, - "--num-tokens", - str(num_tokens), - "--page-size", - str(page_size), - "--enable-pdl", - str(int(enable_pdl)), - ] - try: - out = subprocess.check_output(cmd, stderr=subprocess.STDOUT, text=True) - except (subprocess.CalledProcessError, FileNotFoundError) as e: - print(f"Warning: Nsight Compute not available or failed: {e}") - return -1.0 - # parse ncu output csv - target_metric_dram = "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed" - target_metric_compute = ( - "gpu__compute_memory_throughput.avg.pct_of_peak_sustained_elapsed" - ) - lines = [ - l - for l in out.splitlines() - if l.strip() and not l.startswith("#") and not l.startswith("==") - ] - reader = csv.reader(lines) - header = None - kernel_idx = None - dram_idx = None - compute_idx = None - rows = [] - for row in reader: - if not row: - continue - if header is None: - header = [c.strip() for c in row] - # Locate columns - try: - kernel_idx = header.index("Kernel Name") - except ValueError: - for i, c in enumerate(header): - if c.replace(" ", "").lower() == "kernelname": - kernel_idx = i - break - try: - dram_idx = header.index(target_metric_dram) - except ValueError: - dram_idx = None - try: - compute_idx = header.index(target_metric_compute) - except ValueError: - compute_idx = None - continue - rows.append(row) - if header is None or (dram_idx is None and compute_idx is None): - print("Warning: Unable to parse BW% from Nsight Compute output.") - return -1.0, -1.0 - # Only return the fused kernel's metrics; otherwise fail - fused_dram = None - fused_compute = None - for r in rows: - if len(r) <= max((kernel_idx or 0), dram_idx or 0, compute_idx or 0): - continue - kname = r[kernel_idx] if kernel_idx is not None else "" - kname_l = kname.lower() - # Match our fused kernel robustly - if ( - "ropequantizeappendpagedkvcachekernel".lower() in kname_l - or ("rope" in kname_l and "append" in kname_l) - or "ropequantizeappend" in kname_l - ): - try: - if dram_idx is not None and len(r) > dram_idx: - fused_dram = float(r[dram_idx]) - except Exception: - fused_dram = None - try: - if compute_idx is not None and len(r) > compute_idx: - fused_compute = float(r[compute_idx]) - except Exception: - fused_compute = None - break - if fused_dram is not None or fused_compute is not None: - return ( - fused_dram if fused_dram is not None else -1.0, - fused_compute if fused_compute is not None else -1.0, - ) - print("Warning: Unable to find fused kernel metric in Nsight Compute output.") - return -1.0, -1.0 - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -403,8 +297,11 @@ def _run_ncu_and_get_bw_pct( sys.exit(0) # Get GPU information (for display only) + device = 
torch.device("cuda:0") gpu_name = torch.cuda.get_device_name(0) + gpu_peak_bandwidth = get_gpu_memory_bandwidth(device) print(f"\nDetected GPU: {gpu_name}") + print(f"Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") print() # Token counts to benchmark @@ -417,54 +314,18 @@ def print_config_table(config_name, config_desc): print(f" {config_name.upper()}: {config_desc}") print(f"{'=' * 100}") - # Only use ncu if available. if not available, skip BW% calculations - first_tokens = token_counts[-1] - try: - probe_dram_pct, probe_compute_pct = _run_ncu_and_get_bw_pct( - os.path.abspath(__file__), - config_name, - first_tokens, - page_size_to_benchmark, - False, - ) - show_bw_pct = (probe_dram_pct >= 0) and (probe_compute_pct >= 0) - except Exception as e: - print( - f"Warning: Skipping BW% calculations due to Nsight Compute not available or failed: {e}" + print( + f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'BW% (Peak)':<14} {'TFLOPs':<12}" + ) + print("-" * 70) + for num_tokens in token_counts: + ms, _, _, bw, tflops = benchmark_config( + config_name, num_tokens, page_size=page_size_to_benchmark ) - show_bw_pct = False - - if show_bw_pct: + bw_pct = (bw / gpu_peak_bandwidth) * 100 print( - f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'BW% (DRAM)':<14} {'BW% (Compute)':<16} {'TFLOPs':<12}" + f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {bw_pct:<14.1f} {tflops:<12.3f}" ) - print("-" * 80) - for num_tokens in token_counts: - ms, _, _, bw, tflops = benchmark_config( - config_name, num_tokens, page_size=page_size_to_benchmark - ) - dram_pct, compute_pct = _run_ncu_and_get_bw_pct( - os.path.abspath(__file__), - config_name, - num_tokens, - page_size_to_benchmark, - False, - ) - if dram_pct < 0 or compute_pct < 0: - # If individual row fails, fall back to skipping BW% entirely for consistency - show_bw_pct = False - break - print( - f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {dram_pct:<14.1f} {compute_pct:<16.1f} {tflops:<12.3f}" - ) - if not show_bw_pct: - print(f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'TFLOPs':<12}") - print("-" * 50) - for num_tokens in token_counts: - ms, _, _, bw, tflops = benchmark_config( - config_name, num_tokens, page_size=page_size_to_benchmark - ) - print(f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {tflops:<12.3f}") # Print tables for each configuration print_config_table("mla", "128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)") @@ -476,4 +337,6 @@ def print_config_table(config_name, config_desc): print(" Page size: 32, Batch size: 4") print(" Token range: 1 (single decode) → 8192 (large prefill)") print(f" GPU: {gpu_name}") + print(f" Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print(" BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100") print("=" * 100) diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 936d08380c..5c4c28642d 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -21,6 +21,7 @@ import torch import torch.version +import pynvml from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version @@ -254,6 +255,46 @@ def get_compute_capability(device: torch.device) -> Tuple[int, int]: return torch.cuda.get_device_capability(device.index) +@functools.cache +def get_gpu_memory_bandwidth(device: torch.device) -> float: + """ + Get GPU memory bandwidth in GB/s for the specified CUDA device. 
+ + Args: + device: torch.device object, e.g., torch.device('cuda:0') + + Returns: + float: GPU memory bandwidth (GB/s) + + Raises: + ValueError: If device is not a CUDA device + """ + # Convert to torch.device object if string is passed + if isinstance(device, str): + device = torch.device(device) + + # Check if it's a CUDA device + if device.type != "cuda": + raise ValueError(f"Device must be a CUDA device, got {device}") + + # Get device index + device_index = device.index if device.index is not None else 0 + + # Use pynvml to get bandwidth + pynvml.nvmlInit() + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle) + mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM) + + # Calculate theoretical peak bandwidth (GB/s) + bandwidth = (mem_clock * bus_width * 2) / 8 / 1000 + + return bandwidth + finally: + pynvml.nvmlShutdown() + + def _check_cached_qkv_data_type( q: torch.Tensor, k: torch.Tensor, dtype_q: torch.dtype, dtype_kv: torch.dtype ) -> None: From ec566bc28219e1e1522a8fee0bd2b1b0b28c33c2 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 13 Nov 2025 00:45:42 -0800 Subject: [PATCH 11/13] small fixes --- include/flashinfer/pos_enc.cuh | 227 ++++++++++++++++++++------------- 1 file changed, 135 insertions(+), 92 deletions(-) diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index e0b23dfc17..7901b71e22 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include "layout.cuh" #include "math.cuh" @@ -29,6 +30,30 @@ namespace flashinfer { +struct RopeQuantizeAppendPagedKVCacheParams { + uint32_t nnz; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t rope_dim; + uint32_t no_rope_dim; + size_t q_rope_in_stride_n; + size_t q_rope_in_stride_h; + size_t q_nope_in_stride_n; + size_t q_nope_in_stride_h; + size_t q_rope_out_stride_n; + size_t q_rope_out_stride_h; + size_t q_nope_out_stride_n; + size_t q_nope_out_stride_h; + size_t k_rope_in_stride; + size_t k_rope_in_stride_h; + size_t k_nope_in_stride; + size_t k_nope_in_stride_h; + size_t v_in_stride; + size_t v_in_stride_h; + float quant_scale_q; + float quant_scale_kv; +}; + /*! * \brief An enumeration class that defines different modes for applying RoPE * (Rotary Positional Embeddings). @@ -718,35 +743,6 @@ __global__ void BatchQKApplyRotaryKernel( } } -#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ - if (interleave) { \ - const bool INTERLEAVE = true; \ - __VA_ARGS__ \ - } else { \ - const bool INTERLEAVE = false; \ - __VA_ARGS__ \ - } - -#define DISPATCH_ROPE_DIM(rope_dim, vec_size, ...) \ - if (rope_dim == 16) { \ - constexpr uint32_t bdx = 16 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 32) { \ - constexpr uint32_t bdx = 32 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 64) { \ - constexpr uint32_t bdx = 64 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 128) { \ - constexpr uint32_t bdx = 128 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 256) { \ - constexpr uint32_t bdx = 256 / vec_size; \ - __VA_ARGS__ \ - } else { \ - FLASHINFER_ERROR("Unsupported rope_dim. Supported values: 16, 32, 64, 128, 256"); \ - } - /*! * \brief Unified CUDA kernel to apply RoPE, quantize to FP8, and append to paged cache. * @@ -754,18 +750,13 @@ __global__ void BatchQKApplyRotaryKernel( * Cache-only behaviors are selected with constexpr on the CacheT. 
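 * Work distribution (matching the launch configuration in RopeQuantizeAppendPagedKVCache):
 * threadIdx.x spans rope_dim / vec_size elements of one head, blockIdx.x tiles the nnz
 * tokens (blockDim.y tokens per block), and blockIdx.y walks the per-head sections in
 * order: Q rope, K rope, K nope and, for GQA/MHA caches, V.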
*/ template + typename QuantType, typename CacheT> __global__ void RopeQuantizeAppendPagedKVCacheKernel( DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like, IdType* __restrict__ batch_indices, IdType* __restrict__ positions, - float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz, - uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, - size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, - size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, - size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, - size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, - size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv) { + float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, + const RopeQuantizeAppendPagedKVCacheParams params) { #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif @@ -773,6 +764,29 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( uint32_t by = blockIdx.y; uint32_t bdy = blockDim.y; + // Local aliases for params for readability + const uint32_t nnz = params.nnz; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t rope_dim = params.rope_dim; + const uint32_t no_rope_dim = params.no_rope_dim; + const size_t q_rope_in_stride_n = params.q_rope_in_stride_n; + const size_t q_rope_in_stride_h = params.q_rope_in_stride_h; + const size_t q_nope_in_stride_n = params.q_nope_in_stride_n; + const size_t q_nope_in_stride_h = params.q_nope_in_stride_h; + const size_t q_rope_out_stride_n = params.q_rope_out_stride_n; + const size_t q_rope_out_stride_h = params.q_rope_out_stride_h; + const size_t q_nope_out_stride_n = params.q_nope_out_stride_n; + const size_t q_nope_out_stride_h = params.q_nope_out_stride_h; + const size_t k_rope_in_stride = params.k_rope_in_stride; + const size_t k_rope_in_stride_h = params.k_rope_in_stride_h; + const size_t k_nope_in_stride = params.k_nope_in_stride; + const size_t k_nope_in_stride_h = params.k_nope_in_stride_h; + const size_t v_in_stride = params.v_in_stride; + const size_t v_in_stride_h = params.v_in_stride_h; + const float quant_scale_q = params.quant_scale_q; + const float quant_scale_kv = params.quant_scale_kv; + // Calculate flexible boundaries for block allocation uint32_t rope_chunk_size = rope_dim; uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size; @@ -783,6 +797,9 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel( uint32_t k_rope_end = q_rope_end + num_kv_heads * rope_chunks; uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks; + // Deduce MLA vs GQA/MHA from CacheT + constexpr bool IS_MLA = std::is_same>::value; + vec_t cos, sin; if (bx * bdy + ty < nnz) { const uint32_t idx = bx * bdy + ty; @@ -975,11 +992,11 @@ cudaError_t RopeQuantize( FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - constexpr uint32_t vec_size = 32 / sizeof(DType); - // Use nested macros for runtime->compile-time dispatch for required constexpr values - DISPATCH_ROPE_DIM(rope_dim, vec_size, { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr 
+      constexpr uint32_t bdx = ROPE_DIM / vec_size;
       uint32_t num_threads = 128U;
       uint32_t bdy = num_threads / bdx;
       uint32_t nblks_x = (nnz + bdy - 1) / bdy;
@@ -1065,10 +1082,10 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
     size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h,
     size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv,
     bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
-  constexpr uint32_t vec_size = 32 / sizeof(DType);
-
-  DISPATCH_ROPE_DIM(rope_dim, vec_size, {
+  DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, {
     DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+      constexpr uint32_t vec_size = 32 / sizeof(DType);
+      constexpr uint32_t bdx = ROPE_DIM / vec_size;
       uint32_t num_threads = 128U;
       uint32_t bdy = num_threads / bdx;
       uint32_t nblks_x = (nnz + bdy - 1) / bdy;
@@ -1094,30 +1111,42 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
       config.attrs = attribute;
       config.numAttrs = 1;
 
-      auto kernel =
-          RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, bdx, DType, IdType, QuantType, paged_kv_t<QuantType, IdType>>;
-      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(
-          &config, kernel,
-          // inputs
-          q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in,
-          // q outputs
-          q_rope_out, q_nope_out,
-          // cache + indices
-          paged_kv, batch_indices, positions,
-          // rope tables
-          cos_sin_cache, pos_ids,
-          // sizes
-          nnz, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim,
-          // Q strides (in/out)
-          q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
-          q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
-          // K strides
-          k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, k_nope_in_stride_h,
-          // V strides
-          v_in_stride, v_in_stride_h,
-          // scales
-          quant_scale_q, quant_scale_kv));
+      auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, bdx, DType, IdType, QuantType, paged_kv_t<QuantType, IdType>>;
+      RopeQuantizeAppendPagedKVCacheParams params;
+      params.nnz = nnz;
+      params.num_qo_heads = num_qo_heads;
+      params.num_kv_heads = num_kv_heads;
+      params.rope_dim = rope_dim;
+      params.no_rope_dim = no_rope_dim;
+      params.q_rope_in_stride_n = q_rope_in_stride_n;
+      params.q_rope_in_stride_h = q_rope_in_stride_h;
+      params.q_nope_in_stride_n = q_nope_in_stride_n;
+      params.q_nope_in_stride_h = q_nope_in_stride_h;
+      params.q_rope_out_stride_n = q_rope_out_stride_n;
+      params.q_rope_out_stride_h = q_rope_out_stride_h;
+      params.q_nope_out_stride_n = q_nope_out_stride_n;
+      params.q_nope_out_stride_h = q_nope_out_stride_h;
+      params.k_rope_in_stride = k_rope_in_stride;
+      params.k_rope_in_stride_h = k_rope_in_stride_h;
+      params.k_nope_in_stride = k_nope_in_stride;
+      params.k_nope_in_stride_h = k_nope_in_stride_h;
+      params.v_in_stride = v_in_stride;
+      params.v_in_stride_h = v_in_stride_h;
+      params.quant_scale_q = quant_scale_q;
+      params.quant_scale_kv = quant_scale_kv;
+
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel,
+                                              // inputs
+                                              q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in,
+                                              // q outputs
+                                              q_rope_out, q_nope_out,
+                                              // cache + indices
+                                              paged_kv, batch_indices, positions,
+                                              // rope tables
+                                              cos_sin_cache, pos_ids,
+                                              // params
+                                              params));
     });
   });
 
@@ -1137,10 +1166,10 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
     size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h,
     size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv,
     bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) {
-  constexpr uint32_t vec_size = 32 / sizeof(DType);
-
-  DISPATCH_ROPE_DIM(rope_dim, vec_size, {
+  DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, {
     DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+      constexpr uint32_t vec_size = 32 / sizeof(DType);
+      constexpr uint32_t bdx = ROPE_DIM / vec_size;
       uint32_t num_threads = 128U;
       uint32_t bdy = num_threads / bdx;
       uint32_t nblks_x = (nnz + bdy - 1) / bdy;
@@ -1168,7 +1197,7 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
 
       auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, bdx, DType, IdType, QuantType,
-                                                         paged_kv_mla_t<DType, IdType>>;
+                                                         paged_kv_mla_t<QuantType, IdType>>;
       // For MLA: pass v_in as nullptr, num_kv_heads=1, duplicate 2D K strides for head strides, and
       // 0 V strides
       DType* v_in_nullptr = nullptr;
@@ -1176,27 +1205,41 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
       size_t k_rope_in_stride_h_dup = k_rope_in_stride;
      size_t k_nope_in_stride_h_dup = k_nope_in_stride;
       size_t v_in_stride_zero = 0, v_in_stride_h_zero = 0;
-      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(
-          &config, kernel,
-          // inputs
-          q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in_nullptr,
-          // q outputs
-          q_rope_out, q_nope_out,
-          // cache + indices
-          paged_kv_mla, batch_indices, positions,
-          // rope tables
-          cos_sin_cache, pos_ids,
-          // sizes
-          nnz, num_qo_heads, num_kv_heads_1, rope_dim, no_rope_dim,
-          // Q strides (in/out)
-          q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
-          q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
-          // K strides (2D: duplicate for head stride)
-          k_rope_in_stride, k_rope_in_stride_h_dup, k_nope_in_stride, k_nope_in_stride_h_dup,
-          // V strides (unused for MLA)
-          v_in_stride_zero, v_in_stride_h_zero,
-          // scales
-          quant_scale_q, quant_scale_kv));
+      RopeQuantizeAppendPagedKVCacheParams params;
+      params.nnz = nnz;
+      params.num_qo_heads = num_qo_heads;
+      params.num_kv_heads = 1u;
+      params.rope_dim = rope_dim;
+      params.no_rope_dim = no_rope_dim;
+      params.q_rope_in_stride_n = q_rope_in_stride_n;
+      params.q_rope_in_stride_h = q_rope_in_stride_h;
+      params.q_nope_in_stride_n = q_nope_in_stride_n;
+      params.q_nope_in_stride_h = q_nope_in_stride_h;
+      params.q_rope_out_stride_n = q_rope_out_stride_n;
+      params.q_rope_out_stride_h = q_rope_out_stride_h;
+      params.q_nope_out_stride_n = q_nope_out_stride_n;
+      params.q_nope_out_stride_h = q_nope_out_stride_h;
+      params.k_rope_in_stride = k_rope_in_stride;
+      params.k_rope_in_stride_h = k_rope_in_stride_h_dup;
+      params.k_nope_in_stride = k_nope_in_stride;
+      params.k_nope_in_stride_h = k_nope_in_stride_h_dup;
+      params.v_in_stride = 0;
+      params.v_in_stride_h = 0;
+      params.quant_scale_q = quant_scale_q;
+      params.quant_scale_kv = quant_scale_kv;
+
+      FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel,
+                                              // inputs
+                                              q_rope_in, k_rope_in, q_nope_in, k_nope_in,
+                                              v_in_nullptr,
+                                              // q outputs
+                                              q_rope_out, q_nope_out,
+                                              // cache + indices
+                                              paged_kv_mla, batch_indices, positions,
+                                              // rope tables
+                                              cos_sin_cache, pos_ids,
+                                              // params
+                                              params));
     });
   });

From 72bd59e39ae0c5371fbd8e5e932813683e434d3b Mon Sep 17 00:00:00 2001
From: kahyunnam
Date: Thu, 13 Nov 2025 10:59:28 -0800
Subject: [PATCH 12/13] add missing dispatch macros to utils.cuh

---
 include/flashinfer/utils.cuh | 46 ++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh
index 5b26d7beaf..5effc29fe2 100644
--- a/include/flashinfer/utils.cuh
+++ b/include/flashinfer/utils.cuh
@@ -201,6 +201,52 @@
   }                                                                           \
   }
+
+// convert interleave to compile-time constant
+#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
+  if (interleave) {                                      \
+    const bool INTERLEAVE = true;                        \
+    __VA_ARGS__                                          \
+  } else {                                               \
+    const bool INTERLEAVE = false;                       \
+    __VA_ARGS__                                          \
+  }
+
+#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...)              \
+  switch (rope_dim) {                                           \
+    case 16: {                                                  \
+      constexpr uint32_t ROPE_DIM = 16;                         \
+      __VA_ARGS__                                               \
+      break;                                                    \
+    }                                                           \
+    case 32: {                                                  \
+      constexpr uint32_t ROPE_DIM = 32;                         \
+      __VA_ARGS__                                               \
+      break;                                                    \
+    }                                                           \
+    case 64: {                                                  \
+      constexpr uint32_t ROPE_DIM = 64;                         \
+      __VA_ARGS__                                               \
+      break;                                                    \
+    }                                                           \
+    case 128: {                                                 \
+      constexpr uint32_t ROPE_DIM = 128;                        \
+      __VA_ARGS__                                               \
+      break;                                                    \
+    }                                                           \
+    case 256: {                                                 \
+      constexpr uint32_t ROPE_DIM = 256;                        \
+      __VA_ARGS__                                               \
+      break;                                                    \
+    }                                                           \
+    default: {                                                  \
+      std::ostringstream err_msg;                               \
+      err_msg << "Unsupported ROPE_DIM: " << rope_dim;          \
+      err_msg << ". Supported values: 16, 32, 64, 128, 256";    \
+      err_msg << " in DISPATCH_ROPE_DIM";                       \
+      FLASHINFER_ERROR(err_msg.str());                          \
+    }                                                           \
+  }
+
 #define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
   switch (pos_encoding_mode) {                                                \
     case PosEncodingMode::kNone: {                                            \

From 0f87afd133cf5814cbed744b6624fe21cd17d2c8 Mon Sep 17 00:00:00 2001
From: kahyunnam
Date: Fri, 14 Nov 2025 12:45:35 -0800
Subject: [PATCH 13/13] fix CI: make INTERLEAVE constexpr in DISPATCH_INTERLEAVE

---
 include/flashinfer/utils.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh
index 5effc29fe2..0471bd1081 100644
--- a/include/flashinfer/utils.cuh
+++ b/include/flashinfer/utils.cuh
@@ -204,10 +204,10 @@
 // convert interleave to compile-time constant
 #define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
   if (interleave) {                                      \
-    const bool INTERLEAVE = true;                        \
+    constexpr bool INTERLEAVE = true;                    \
     __VA_ARGS__                                          \
   } else {                                               \
-    const bool INTERLEAVE = false;                       \
+    constexpr bool INTERLEAVE = false;                   \
     __VA_ARGS__                                          \
   }
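
Taken together, DISPATCH_ROPE_DIM and DISPATCH_INTERLEAVE turn the runtime values rope_dim and interleave into constexpr names inside the dispatched block, which is what lets the host code compute vec_size and bdx at compile time and use them as template arguments when instantiating the kernel. The following is a minimal, self-contained host-only sketch of that pattern using the constexpr form of the macro from PATCH 13/13; the trimmed macro bodies (two ROPE_DIM cases only) and the names demo_kernel / launch_demo are illustrative stand-ins, not FlashInfer code.

#include <cstdint>
#include <cstdio>
#include <stdexcept>

#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
  if (interleave) {                                      \
    constexpr bool INTERLEAVE = true;                    \
    __VA_ARGS__                                          \
  } else {                                               \
    constexpr bool INTERLEAVE = false;                   \
    __VA_ARGS__                                          \
  }

#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...)         \
  switch (rope_dim) {                                      \
    case 64: {                                             \
      constexpr uint32_t ROPE_DIM = 64;                    \
      __VA_ARGS__                                          \
      break;                                               \
    }                                                      \
    case 128: {                                            \
      constexpr uint32_t ROPE_DIM = 128;                   \
      __VA_ARGS__                                          \
      break;                                               \
    }                                                      \
    default:                                               \
      throw std::invalid_argument("unsupported rope_dim"); \
  }

// Stand-in for the real kernel template: the dispatched values arrive as non-type
// template parameters, so they must be constexpr at the instantiation site.
template <bool INTERLEAVE, uint32_t vec_size, uint32_t bdx>
void demo_kernel() {
  std::printf("interleave=%d vec_size=%u bdx=%u\n", int(INTERLEAVE), unsigned(vec_size),
              unsigned(bdx));
}

// Mirrors the shape of the host-side launch code in the patch: pick ROPE_DIM and
// INTERLEAVE at runtime, derive vec_size/bdx as compile-time constants, instantiate.
template <typename DType>
void launch_demo(uint32_t rope_dim, bool interleave) {
  DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, {
    DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
      constexpr uint32_t vec_size = 32 / sizeof(DType);
      constexpr uint32_t bdx = ROPE_DIM / vec_size;
      demo_kernel<INTERLEAVE, vec_size, bdx>();
    });
  });
}

int main() {
  launch_demo<uint16_t>(/*rope_dim=*/64, /*interleave=*/true);   // 16-bit DType stand-in
  launch_demo<uint16_t>(/*rope_dim=*/128, /*interleave=*/false);
  return 0;
}

Each supported rope_dim gets its own template instantiation, so widening the dispatch table (16/32/64/128/256 in the real macro) trades some compile time and binary size for keeping vec_size and bdx as true compile-time constants inside the kernel.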