diff --git a/benchmarks/bench_rope_quantize_fp8_append_cache.py b/benchmarks/bench_rope_quantize_fp8_append_cache.py new file mode 100644 index 0000000000..3119b9fef8 --- /dev/null +++ b/benchmarks/bench_rope_quantize_fp8_append_cache.py @@ -0,0 +1,342 @@ +""" +Copyright (c) 2024 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import sys +import argparse +import flashinfer +import numpy as np +import torch +from flashinfer.testing.utils import bench_gpu_time_with_cudagraph +from flashinfer.utils import get_gpu_memory_bandwidth + +# Add the project root to Python path to import test helpers +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from tests.test_helpers.rope_reference import RotaryEmbedding + + +def benchmark_config( + config_name, + num_tokens, + batch_size=4, + page_size=16, + enable_pdl=False, + single_run=False, +): + """Benchmark a specific attention configuration with paged KV cache append.""" + input_dtype = torch.bfloat16 + device = "cuda" + quant_dtype = torch.float8_e4m3fn + + # Configuration-specific parameters + if config_name == "mla": + # MLA: DeepSeek-style multi-latent attention + num_qo_heads, num_kv_heads = 128, 1 + rope_dim, no_rope_dim = 64, 512 + elif config_name == "gqa": + # GQA: Grouped-query attention (e.g., Llama-style) + num_qo_heads, num_kv_heads = 32, 8 + rope_dim, no_rope_dim = 64, 64 + elif config_name == "mha": + # MHA: Standard multi-head attention + num_qo_heads, num_kv_heads = 32, 32 + rope_dim, no_rope_dim = 64, 64 + else: + raise ValueError(f"Unknown config: {config_name}") + + head_dim = rope_dim + no_rope_dim + + # Create input tensors + if config_name == "mla": + # MLA: 2D K tensors (shared) + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + v = None + else: + # GQA/MHA: 3D K/V tensors + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Create RoPE reference for cos/sin cache (ensure it covers this run) + max_seq_len = int(num_tokens) + rope_ref = RotaryEmbedding( + head_size=head_dim, + rotary_dim=rope_dim, + max_position_embeddings=max_seq_len, + base=10000, + is_neox_style=False, + dtype=input_dtype, + device=device, + ) + pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata (single request with all tokens) + 
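+ # Illustrative example of the metadata built below (assuming num_tokens=40,
+ # page_size=16, batch_size=4): all tokens go to request 0, the rest stay empty:
+ #   kv_append_length  = [40, 0, 0, 0]  -> kv_append_indptr = [0, 40, 40, 40, 40]
+ #   num_pages_per_req = [3, 0, 0, 0]   -> kv_page_indptr   = [0, 3, 3, 3, 3]
+ #   kv_page_indices   = [0, 1, 2]         (pages allocated contiguously)
+ #   kv_last_page_len  = [8, 0, 0, 0]      (40 = 2 * 16 + 8)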
kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + + # Get batch_indices and positions + seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Allocate caches + max_pages = kv_page_indptr[-1].item() + + if config_name == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + paged_kv_cache = (ckv_cache, kpe_cache) + else: + # GQA/MHA: use NHD layout + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + paged_kv_cache = (k_cache, v_cache) + + run_idx = 0 + + def execute(): + if single_run: + import torch.cuda.nvtx as nvtx + + nvtx.range_push("rope_append") + nonlocal run_idx + run_idx += 1 + + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + v=v, + cos_sin_cache=rope_ref.cos_sin_cache, + pos_ids=pos_ids, + paged_kv_cache=paged_kv_cache, + kv_indices=kv_page_indices, + kv_indptr=kv_page_indptr, + batch_indices=batch_indices, + positions=positions, + page_size=page_size, + kv_layout="NHD" if config_name != "mla" else "NHD", + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + if single_run: + # Ensure kernels complete inside the NVTX range for ncu filtering + torch.cuda.synchronize() + nvtx.range_pop() + + if single_run: + execute() + return None, None, None, None, None + measurements = bench_gpu_time_with_cudagraph(execute) + + # Calculate I/O bytes + # Inputs: q_rope, k_rope, q_nope, k_nope, v (if not MLA), cos_sin_cache, pos_ids + io_bytes = ( + q_rope.numel() * q_rope.element_size() + + k_rope.numel() * k_rope.element_size() + + q_nope.numel() * q_nope.element_size() + + k_nope.numel() * k_nope.element_size() + + rope_ref.cos_sin_cache.numel() * rope_ref.cos_sin_cache.element_size() + + pos_ids.numel() * pos_ids.element_size() + ) + + if v is not None: + io_bytes += v.numel() * v.element_size() + + # Outputs: q_rope_out, q_nope_out (FP8), cache writes (FP8) + io_bytes += ( + q_rope.numel() * torch.finfo(quant_dtype).bits // 8 + + q_nope.numel() * torch.finfo(quant_dtype).bits // 8 + ) + + if config_name == "mla": + # MLA writes to ckv_cache and kpe_cache + io_bytes += ( + num_tokens * no_rope_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * rope_dim * torch.finfo(quant_dtype).bits // 8 + ) + else: + # GQA/MHA 
writes to k_cache and v_cache + io_bytes += ( + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + + num_tokens * num_kv_heads * head_dim * torch.finfo(quant_dtype).bits // 8 + ) + + # Calculate statistics + ms = np.median(measurements) + min_ms = np.percentile(measurements, 20) + max_ms = np.percentile(measurements, 80) + + # Calculate bandwidth in GB/s + bandwidth_gb_s = io_bytes / ms / 1e6 + + # Calculate TFLOPs (FP operations) + # RoPE: 6 FLOPs per dimension pair (2 muls + 1 sub for real, 2 muls + 1 add for imag) + # For Q: num_tokens * num_qo_heads * (rope_dim/2) pairs * 6 FLOPs + # For K: depends on architecture + q_flops = num_tokens * num_qo_heads * (rope_dim / 2) * 6 + + if config_name == "mla": + # MLA: K is 2D (no head dimension) + k_flops = num_tokens * (rope_dim / 2) * 6 + else: + # GQA/MHA: K is 3D (has head dimension) + k_flops = num_tokens * num_kv_heads * (rope_dim / 2) * 6 + + total_flops = q_flops + k_flops + tflops = ( + total_flops / ms / 1e9 + ) # TFLOPs (operations per ms = operations per second / 1e12) + + return ms, min_ms, max_ms, bandwidth_gb_s, tflops + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ncu-single", action="store_true", help="Run a single execute() for ncu" + ) + parser.add_argument( + "--config", type=str, default="", help="Config name: mla/gqa/mha" + ) + parser.add_argument("--num-tokens", type=int, default=0) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--enable-pdl", type=int, default=0) + args, unknown = parser.parse_known_args() + + if args.ncu_single: + # Minimal single-run for ncu profiling + cfg = args.config or "mla" + ntok = int(args.num_tokens) + pgsz = int(args.page_size) + en_pdl = bool(int(args.enable_pdl)) + # Force a single execution path + benchmark_config(cfg, ntok, page_size=pgsz, enable_pdl=en_pdl, single_run=True) + sys.exit(0) + + # Get GPU information (for display only) + device = torch.device("cuda:0") + gpu_name = torch.cuda.get_device_name(0) + gpu_peak_bandwidth = get_gpu_memory_bandwidth(device) + print(f"\nDetected GPU: {gpu_name}") + print(f"Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print() + + # Token counts to benchmark + token_counts = [1, 32, 128, 384, 768, 1024, 2048, 4096, 8192] + + # Helper function to print a table for a specific configuration + def print_config_table(config_name, config_desc): + page_size_to_benchmark = 32 + print(f"\n{'=' * 100}") + print(f" {config_name.upper()}: {config_desc}") + print(f"{'=' * 100}") + + print( + f"{'Tokens':<10} {'Time (ms)':<12} {'BW (GB/s)':<12} {'BW% (Peak)':<14} {'TFLOPs':<12}" + ) + print("-" * 70) + for num_tokens in token_counts: + ms, _, _, bw, tflops = benchmark_config( + config_name, num_tokens, page_size=page_size_to_benchmark + ) + bw_pct = (bw / gpu_peak_bandwidth) * 100 + print( + f"{num_tokens:<10} {ms:<12.5f} {bw:<12.2f} {bw_pct:<14.1f} {tflops:<12.3f}" + ) + + # Print tables for each configuration + print_config_table("mla", "128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)") + print_config_table("gqa", "32 Q heads, 8 K heads, 64+64 dims (Llama-style)") + print_config_table("mha", "32 Q heads, 32 K heads, 64+64 dims (Standard)") + + print("\n" + "=" * 100) + print("Configuration details:") + print(" Page size: 32, Batch size: 4") + print(" Token range: 1 (single decode) → 8192 (large prefill)") + print(f" GPU: {gpu_name}") + print(f" Theoretical Peak Memory Bandwidth: {gpu_peak_bandwidth:.2f} GB/s") + print(" BW% 
calculated as: (achieved_bandwidth / peak_bandwidth) * 100") + print("=" * 100) diff --git a/csrc/flashinfer_rope_binding.cu b/csrc/flashinfer_rope_binding.cu index 23124064d8..94809da735 100644 --- a/csrc/flashinfer_rope_binding.cu +++ b/csrc/flashinfer_rope_binding.cu @@ -45,9 +45,19 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave, bool enable_pdl); +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, TensorView k_cache, TensorView v_cache, TensorView ckv_cache, + TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, + TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q, + double quant_scale_kv, bool interleave, bool enable_pdl); + TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids); TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache, + rope_quantize_append_paged_kv_cache); diff --git a/csrc/rope.cu b/csrc/rope.cu index 78cdcad405..40388d9412 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -420,3 +420,198 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope }); }); } + +/*! + * TVM FFI binding for fused RoPE + quantization + paged KV cache append kernel + * + * Validates tensor shapes, dimensions, and data types, then dispatches to the templated + * RopeQuantizeAppendPagedKVCache CUDA kernel implementation. 
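+ *
+ * The MLA path is selected when ckv_cache/kpe_cache are non-empty and
+ * k_cache/v_cache are empty; otherwise the GQA/MHA path is used and
+ * k_cache/v_cache must be provided. The unused cache pair should be passed
+ * as empty (size-0) tensors.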
+ */ +void rope_quantize_append_paged_kv_cache( + TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in, + TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache, + TensorView pos_ids, + // Paged cache tensors + TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache, + TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions, + int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv, + bool interleave, bool enable_pdl) { + // Validate inputs + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out); + CHECK_INPUT(cos_sin_cache); + CHECK_INPUT(pos_ids); + CHECK_INPUT(kv_indices); + CHECK_INPUT(kv_indptr); + CHECK_INPUT(batch_indices); + CHECK_INPUT(positions); + + // Extract dimensions + uint32_t rope_dim = q_rope_in.size(-1); + uint32_t no_rope_dim = q_nope_in.size(-1); + uint32_t nnz = q_rope_in.size(0); + uint32_t num_qo_heads = q_rope_in.size(1); + + // Validate dimensions + TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype()); + + // Validate input/output dtypes + TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16) + << "Input dtype must be float16 or bfloat16"; + TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2) + << "Output dtype must be float8_e4m3fn or float8_e5m2"; + + // Q tensors are always 3D + CHECK_DIM(3, q_rope_in); + CHECK_DIM(3, q_nope_in); + CHECK_DIM(3, q_rope_out); + CHECK_DIM(3, q_nope_out); + + // Detect architecture based on cache presence/layout (not K dimensionality) + QKVLayout kv_layout = QKVLayout(kv_layout_code); + bool has_mla_caches = (ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr); + bool has_gqa_caches = (k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr); + bool is_mla = has_mla_caches && !has_gqa_caches; + uint32_t num_kv_heads; + uint32_t batch_size = kv_indptr.size(0) - 1; + + // Require 3D K tensors in both paths; for MLA head dim must be 1 + CHECK_DIM(3, k_rope_in); + CHECK_DIM(3, k_nope_in); + if (is_mla) { + num_kv_heads = 1; + TVM_FFI_ICHECK_EQ(k_rope_in.size(1), 1) << "MLA expects K rope head dim == 1"; + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), 1) << "MLA expects K nope head dim == 1"; + // V can be empty for MLA + TVM_FFI_ICHECK(v_in.data_ptr() == nullptr || v_in.size(0) == 0) + << "MLA should not have V input (or it should be empty)"; + // Validate MLA cache tensors are provided + TVM_FFI_ICHECK(ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr) + << "MLA requires ckv_cache and kpe_cache"; + CHECK_DIM(3, ckv_cache); // (max_pages, page_size, ckv_dim) + CHECK_DIM(3, kpe_cache); // (max_pages, page_size, kpe_dim) + TVM_FFI_ICHECK_EQ(ckv_cache.size(2), no_rope_dim); + TVM_FFI_ICHECK_EQ(kpe_cache.size(2), rope_dim); + } else { + // GQA/MHA validation + num_kv_heads = k_rope_in.size(1); + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), 
num_kv_heads); + // V is required for GQA/MHA + CHECK_DIM(3, v_in); + TVM_FFI_ICHECK_EQ(v_in.size(0), nnz); + TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads); + // Validate GQA/MHA cache tensors are provided + TVM_FFI_ICHECK(k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr) + << "GQA/MHA requires k_cache and v_cache"; + // Cache must be 4D + CHECK_DIM(4, k_cache); + CHECK_DIM(4, v_cache); + } + + // Extract Q strides + const uint32_t q_rope_in_stride_n = q_rope_in.stride(0); + const uint32_t q_rope_in_stride_h = q_rope_in.stride(1); + const uint32_t q_nope_in_stride_n = q_nope_in.stride(0); + const uint32_t q_nope_in_stride_h = q_nope_in.stride(1); + const uint32_t q_rope_out_stride_n = q_rope_out.stride(0); + const uint32_t q_rope_out_stride_h = q_rope_out.stride(1); + const uint32_t q_nope_out_stride_n = q_nope_out.stride(0); + const uint32_t q_nope_out_stride_h = q_nope_out.stride(1); + + // Extract K strides + uint32_t k_rope_in_stride, k_nope_in_stride; + uint32_t k_rope_in_stride_h, k_nope_in_stride_h; + uint32_t v_in_stride = 0, v_in_stride_h = 0; + + k_rope_in_stride = k_rope_in.stride(0); + k_nope_in_stride = k_nope_in.stride(0); + k_rope_in_stride_h = k_rope_in.stride(1); + k_nope_in_stride_h = k_nope_in.stride(1); + if (!is_mla) { + v_in_stride = v_in.stride(0); + v_in_stride_h = v_in.stride(1); + } + + cudaSetDevice(q_rope_in.device().device_id); + const cudaStream_t stream = get_stream(q_rope_in.device()); + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] { + cudaError_t status; + + if (is_mla) { + // MLA: Construct paged_kv_mla_t struct + auto ckv_strides = ckv_cache.strides(); + auto kpe_strides = kpe_cache.strides(); + + paged_kv_mla_t paged_kv_mla( + page_size, no_rope_dim, rope_dim, batch_size, + static_cast(ckv_cache.data_ptr()), ckv_strides.data(), + static_cast(kpe_cache.data_ptr()), kpe_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedMLACache( + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv_mla, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim, + q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, + q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, + k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave, + enable_pdl, stream); + + } else { + // GQA/MHA: Construct paged_kv_t struct + auto k_strides = k_cache.strides(); + auto v_strides = v_cache.strides(); + uint32_t head_dim = rope_dim + no_rope_dim; + + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, + static_cast(k_cache.data_ptr()), + static_cast(v_cache.data_ptr()), k_strides.data(), + static_cast(kv_indices.data_ptr()), + static_cast(kv_indptr.data_ptr()), + nullptr // last_page_len not needed for this kernel + ); + + status = RopeQuantizeAppendPagedKVCache( + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + 
static_cast(v_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), paged_kv, + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, + no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, + q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, + q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, + k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv, + interleave, enable_pdl, stream); + } + + TVM_FFI_ICHECK(status == cudaSuccess) + << "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status); + return true; + }); + }); +} diff --git a/flashinfer/rope.py b/flashinfer/rope.py index 7884c439be..dea6995bcf 100644 --- a/flashinfer/rope.py +++ b/flashinfer/rope.py @@ -226,6 +226,105 @@ def _fake_rope_quantize( pass +@register_custom_op( + "flashinfer::rope_quantize_append_paged_kv_cache", + mutates_args=( + "q_rope_out", + "q_nope_out", + "k_cache", + "v_cache", + "ckv_cache", + "kpe_cache", + ), +) +def _rope_quantize_fp8_append_paged_kv_cache( + q_rope_in: torch.Tensor, + k_rope_in: torch.Tensor, + q_nope_in: torch.Tensor, + k_nope_in: torch.Tensor, + v_in: torch.Tensor, + q_rope_out: torch.Tensor, + q_nope_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_layout_code: int, + page_size: int, + quant_scale_q: float, + quant_scale_kv: float, + interleave: bool, + enable_pdl: bool, +) -> None: + r"""Custom operator that routes to the CUDA kernel implementation. + + Fuses RoPE application, FP8 quantization, and paged KV cache append into a single kernel. + + Converts is_neox parameter to interleave format and dispatches to the underlying + CUDA kernel via the JIT-compiled module. 
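+
+    This private op expects fully normalized inputs: K tensors are always 3D
+    (MLA K is unsqueezed to ``(nnz, 1, dim)``), and the unused cache pair
+    (``ckv_cache``/``kpe_cache`` for GQA/MHA, ``k_cache``/``v_cache`` for MLA)
+    is passed as size-0 placeholder tensors. User code should call the public
+    wrapper ``rope_quantize_fp8_append_paged_kv_cache`` instead, which performs
+    this normalization.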
+ """ + get_rope_module().rope_quantize_append_paged_kv_cache( + q_rope_in, + k_rope_in, + q_nope_in, + k_nope_in, + v_in, + q_rope_out, + q_nope_out, + cos_sin_cache, + pos_ids, + k_cache, + v_cache, + ckv_cache, + kpe_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + kv_layout_code, + page_size, + quant_scale_q, + quant_scale_kv, + interleave, + enable_pdl, + ) + + +@register_fake_op("flashinfer::rope_quantize_append_paged_kv_cache") +def _fake_rope_quantize_fp8_append_paged_kv_cache( + q_rope_in: torch.Tensor, + k_rope_in: torch.Tensor, + q_nope_in: torch.Tensor, + k_nope_in: torch.Tensor, + v_in: torch.Tensor, + q_rope_out: torch.Tensor, + q_nope_out: torch.Tensor, + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + ckv_cache: torch.Tensor, + kpe_cache: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + kv_layout_code: int, + page_size: int, + quant_scale_q: float, + quant_scale_kv: float, + interleave: bool, + enable_pdl: bool, +) -> None: + pass + + @register_custom_op( "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope") ) @@ -1186,8 +1285,8 @@ def mla_rope_quantize_fp8( def rope_quantize_fp8( q_rope: torch.Tensor, k_rope: torch.Tensor, - q_nope: torch.Tensor, - k_nope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], cos_sin_cache: torch.Tensor, pos_ids: torch.Tensor, is_neox: bool = True, @@ -1214,12 +1313,12 @@ def rope_quantize_fp8( k_rope : torch.Tensor Key tensor (rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, rope_dim)``. For MLA: ``(nnz, rope_dim)``. Must be float16 or bfloat16. - q_nope : torch.Tensor + q_nope : Optional[torch.Tensor] Query tensor (non-rotary dimensions), shape: ``(nnz, num_qo_heads, no_rope_dim)``. - Must be float16 or bfloat16. - k_nope : torch.Tensor + If ``None``, treated as zero-dim: a size-0 tensor will be created internally. + k_nope : Optional[torch.Tensor] Key tensor (non-rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, no_rope_dim)``. - For MLA: ``(nnz, no_rope_dim)``. Must be float16 or bfloat16. + For MLA: ``(nnz, no_rope_dim)``. If ``None``, treated as zero-dim and created internally. cos_sin_cache : torch.Tensor Precomputed cosine and sine values, shape: ``(max_seq_len, rope_dim)``. First half contains cosine values, second half contains sine values. Must be float32. 
@@ -1254,6 +1353,23 @@ def rope_quantize_fp8( if cos_sin_cache.dtype != torch.float32: raise ValueError("cos_sin_cache should be float32") + # Allow None for nope tensors and normalize to size-0 tensors with correct shapes + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + is_mla = k_rope.ndim == 2 + num_kv_heads = 1 if is_mla else k_rope.shape[1] + if q_nope is None: + q_nope = torch.empty( + nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device + ) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + k_nope = torch.empty( + nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device + ) + # Infer quantize_dtype from output tensors or default to float8_e4m3fn if quantize_dtype is None: for out in (q_rope_out, k_rope_out, q_nope_out, k_nope_out): @@ -1303,3 +1419,239 @@ def rope_quantize_fp8( ) return q_rope_out, k_rope_out, q_nope_out, k_nope_out + + +def rope_quantize_fp8_append_paged_kv_cache( + q_rope: torch.Tensor, + k_rope: torch.Tensor, + q_nope: Optional[torch.Tensor], + k_nope: Optional[torch.Tensor], + v: Optional[torch.Tensor], + cos_sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + paged_kv_cache: Tuple[torch.Tensor, torch.Tensor], + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + batch_indices: torch.Tensor, + positions: torch.Tensor, + is_neox: bool = True, + quantize_dtype: Optional[torch.dtype] = None, + quant_scale_q: float = 1.0, + quant_scale_kv: float = 1.0, + page_size: int = 16, + kv_layout: str = "NHD", + q_rope_out: Optional[torch.Tensor] = None, + q_nope_out: Optional[torch.Tensor] = None, + enable_pdl: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Apply RoPE (Rotary Positional Embeddings), quantize to FP8, and append K/V to paged cache. + + This fused function applies RoPE to query/key (Q/K) rotary dimension tensors, quantizes all Q/K tensors + (and V for GQA/MHA) to FP8 format, and directly appends the quantized K/V to a paged KV cache. + It returns quantized Q tensors for use in attention computation. Supports MLA, GQA, and MHA + architectures with automatic detection based on input tensor shapes. + + Parameters + ---------- + q_rope : torch.Tensor + Query tensor (rotary dimensions), shape: ``(nnz, num_qo_heads, rope_dim)``. + Must be float16 or bfloat16. + k_rope : torch.Tensor + Key tensor (rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, rope_dim)``. + For MLA: ``(nnz, rope_dim)``. Must be float16 or bfloat16. + q_nope : torch.Tensor + Query tensor (non-rotary dimensions), shape: ``(nnz, num_qo_heads, no_rope_dim)``. + Must be float16 or bfloat16. + k_nope : torch.Tensor + Key tensor (non-rotary dimensions). For GQA/MHA: ``(nnz, num_kv_heads, no_rope_dim)``. + For MLA: ``(nnz, no_rope_dim)``. Must be float16 or bfloat16. + v : Optional[torch.Tensor] + Value tensor for GQA/MHA: ``(nnz, num_kv_heads, head_dim)``. Must be float16 or bfloat16. + For MLA: pass ``None`` (MLA does not use separate V; K non-RoPE acts as compressed KV). + cos_sin_cache : torch.Tensor + Precomputed cosine and sine values, shape: ``(max_seq_len, rope_dim)``. + First half contains cosine values, second half contains sine values. Must be float32. + pos_ids : torch.Tensor + Position indices for each token, shape: ``(nnz,)``. 
+ paged_kv_cache : Tuple[torch.Tensor, torch.Tensor] + For MLA: ``(ckv_cache, kpe_cache)`` where: + - ckv_cache: ``(max_pages, page_size, no_rope_dim)`` in FP8 + - kpe_cache: ``(max_pages, page_size, rope_dim)`` in FP8 + For GQA/MHA: ``(k_cache, v_cache)`` where: + - k_cache: ``(max_pages, page_size, num_kv_heads, head_dim)`` or + ``(max_pages, num_kv_heads, page_size, head_dim)`` depending on layout, in FP8 + - v_cache: same shape as k_cache, in FP8 + kv_indices : torch.Tensor + Page indices mapping, shape: ``(total_pages,)``. Typically ``torch.arange(total_pages)``. + kv_indptr : torch.Tensor + Page indptr array for each request, shape: ``(batch_size + 1,)``. + ``kv_indptr[i]`` is the starting page index for request ``i``. + batch_indices : torch.Tensor + Batch index for each token, shape: ``(nnz,)``. Maps each token to its request. + positions : torch.Tensor + Position within each request's sequence for each token, shape: ``(nnz,)``. + is_neox : bool + RoPE layout style. If ``True`` (default), use non-interleaved layout (first/second half). + If ``False``, use interleaved layout (even/odd dimensions). + quantize_dtype : Optional[torch.dtype] + Target quantization dtype. If ``None``, inferred from output tensors or defaults to + ``torch.float8_e4m3fn``. Must be ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. + quant_scale_q : float + Quantization scaling factor for query tensors, default: ``1.0``. + quant_scale_kv : float + Quantization scaling factor for key/value tensors, default: ``1.0``. + page_size : int + Number of entries per page in the paged cache, default: ``16``. + kv_layout : str + Cache memory layout for GQA/MHA. Options: ``"NHD"`` (page, seq, head, dim) or + ``"HND"`` (page, head, seq, dim). Default: ``"NHD"``. Ignored for MLA. + q_rope_out : Optional[torch.Tensor] + Pre-allocated output tensor for quantized query (rotary). If ``None``, allocated automatically. + q_nope_out : Optional[torch.Tensor] + Pre-allocated output tensor for quantized query (non-rotary). If ``None``, allocated automatically. + enable_pdl : bool + Whether to enable PDL (Programmatic Dependent Launch). Default: ``False``. + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + Quantized query tensors: (q_rope_out, q_nope_out). + K/V are written directly to the paged cache and not returned. + + Notes + ----- + - Architecture detection: Automatically distinguishes MLA (2D K tensors) from GQA/MHA (3D K tensors). + - MLA writes K-RoPE to ``kpe_cache`` and K-noRoPE to ``ckv_cache``; V is not used. + - GQA/MHA writes full K (RoPE+noRoPE) to ``k_cache`` and V to ``v_cache``. + - The ``batch_indices`` and ``positions`` tensors are typically obtained from + ``flashinfer.get_batch_indices_positions()``. + - Cache tensors must already be allocated in the target FP8 dtype. 
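+
+    Examples
+    --------
+    A minimal MLA sketch with illustrative shapes (single request, single page,
+    FP8 caches pre-allocated by the caller; the cos/sin table here is a random
+    placeholder standing in for a real precomputed cache, first half cos and
+    second half sin):
+
+    >>> import torch
+    >>> import flashinfer
+    >>> nnz, num_qo_heads, rope_dim, no_rope_dim, page_size = 4, 128, 64, 512, 16
+    >>> dev, dt = "cuda", torch.bfloat16
+    >>> q_rope = torch.randn(nnz, num_qo_heads, rope_dim, dtype=dt, device=dev)
+    >>> q_nope = torch.randn(nnz, num_qo_heads, no_rope_dim, dtype=dt, device=dev)
+    >>> k_rope = torch.randn(nnz, rope_dim, dtype=dt, device=dev)  # 2D K selects the MLA path
+    >>> k_nope = torch.randn(nnz, no_rope_dim, dtype=dt, device=dev)
+    >>> ckv_cache = torch.zeros(1, page_size, no_rope_dim, dtype=torch.float8_e4m3fn, device=dev)
+    >>> kpe_cache = torch.zeros(1, page_size, rope_dim, dtype=torch.float8_e4m3fn, device=dev)
+    >>> kv_indices = torch.arange(1, dtype=torch.int32, device=dev)
+    >>> kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device=dev)
+    >>> batch_indices = torch.zeros(nnz, dtype=torch.int32, device=dev)
+    >>> positions = torch.arange(nnz, dtype=torch.int32, device=dev)
+    >>> pos_ids = torch.arange(nnz, dtype=torch.int32, device=dev)
+    >>> cos_sin_cache = torch.randn(4096, rope_dim, dtype=torch.float32, device=dev)
+    >>> q_rope_fp8, q_nope_fp8 = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache(
+    ...     q_rope, k_rope, q_nope, k_nope, None, cos_sin_cache, pos_ids,
+    ...     (ckv_cache, kpe_cache), kv_indices, kv_indptr, batch_indices, positions,
+    ...     is_neox=False, page_size=page_size)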
+ """ + if cos_sin_cache.dtype != torch.float32: + raise ValueError("cos_sin_cache should be float32") + + # Detect architecture + is_mla = k_rope.ndim == 2 + + # Allow None for nope tensors and normalize to size-0 tensors with correct shapes + nnz = q_rope.shape[0] + num_qo_heads = q_rope.shape[1] + if q_nope is None: + q_nope = torch.empty( + nnz, num_qo_heads, 0, dtype=q_rope.dtype, device=q_rope.device + ) + if k_nope is None: + if is_mla: + k_nope = torch.empty(nnz, 0, dtype=k_rope.dtype, device=k_rope.device) + else: + num_kv_heads = k_rope.shape[1] + k_nope = torch.empty( + nnz, num_kv_heads, 0, dtype=k_rope.dtype, device=k_rope.device + ) + + # Infer quantize_dtype from output tensors or default + if quantize_dtype is None: + if q_rope_out is not None: + quantize_dtype = q_rope_out.dtype + elif q_nope_out is not None: + quantize_dtype = q_nope_out.dtype + else: + quantize_dtype = torch.float8_e4m3fn + + # Allocate Q output tensors if not provided + if q_rope_out is None: + q_rope_out = torch.empty_like(q_rope, dtype=quantize_dtype) + if q_nope_out is None: + q_nope_out = torch.empty_like(q_nope, dtype=quantize_dtype) + + # Handle MLA normalization and V (create empty dummy tensor, not used) + if is_mla: + # Normalize MLA K tensors to 3D (nnz, 1, dim) so C++ binding can always assume 3D + if k_rope.ndim == 2: + k_rope = k_rope.unsqueeze(1) + if k_nope.ndim == 2: + k_nope = k_nope.unsqueeze(1) + if v is None: + v = torch.empty(0, dtype=q_rope.dtype, device=q_rope.device) + else: + raise ValueError("MLA should not have V input (pass None)") + + # Unpack and validate cache tensors + if len(paged_kv_cache) != 2: + raise ValueError("paged_kv_cache must be a tuple of 2 tensors") + + cache_0, cache_1 = paged_kv_cache + + if is_mla: + # MLA: Expect (ckv_cache, kpe_cache) + ckv_cache = cache_0 + kpe_cache = cache_1 + if ckv_cache.dtype != quantize_dtype or kpe_cache.dtype != quantize_dtype: + raise ValueError( + f"MLA cache dtype mismatch: expected {quantize_dtype}, " + f"got ckv={ckv_cache.dtype}, kpe={kpe_cache.dtype}" + ) + if ckv_cache.ndim != 3 or kpe_cache.ndim != 3: + raise ValueError( + f"MLA cache must be 3D: (max_pages, page_size, dim), " + f"got ckv={ckv_cache.ndim}D, kpe={kpe_cache.ndim}D" + ) + # Create dummy tensors for GQA/MHA cache (not used) + k_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + v_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + else: + # GQA/MHA: Expect (k_cache, v_cache) + k_cache = cache_0 + v_cache = cache_1 + # Validate V input is provided for GQA/MHA + if v is None: + raise ValueError( + "GQA/MHA expects a V tensor, but got None. " + "Only MLA uses None for V (compressed KV representation)." 
+ ) + if k_cache.dtype != quantize_dtype or v_cache.dtype != quantize_dtype: + raise ValueError( + f"GQA/MHA cache dtype mismatch: expected {quantize_dtype}, " + f"got k={k_cache.dtype}, v={v_cache.dtype}" + ) + if k_cache.ndim != 4 or v_cache.ndim != 4: + raise ValueError( + f"GQA/MHA cache must be 4D, got k={k_cache.ndim}D, v={v_cache.ndim}D" + ) + # Create dummy tensors for MLA cache (not used) + ckv_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + kpe_cache = torch.empty(0, dtype=quantize_dtype, device=q_rope.device) + + # Import TensorLayout enum + from .utils import TensorLayout + + kv_layout_code = TensorLayout[kv_layout].value + + # Call custom op + _rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + q_rope_out, + q_nope_out, + cos_sin_cache, + pos_ids, + k_cache, + v_cache, + ckv_cache, + kpe_cache, + kv_indices, + kv_indptr, + batch_indices, + positions, + kv_layout_code, + page_size, + quant_scale_q, + quant_scale_kv, + not is_neox, # interleave + enable_pdl, + ) + + return q_rope_out, q_nope_out diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 936d08380c..5c4c28642d 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -21,6 +21,7 @@ import torch import torch.version +import pynvml from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version @@ -254,6 +255,46 @@ def get_compute_capability(device: torch.device) -> Tuple[int, int]: return torch.cuda.get_device_capability(device.index) +@functools.cache +def get_gpu_memory_bandwidth(device: torch.device) -> float: + """ + Get GPU memory bandwidth in GB/s for the specified CUDA device. + + Args: + device: torch.device object, e.g., torch.device('cuda:0') + + Returns: + float: GPU memory bandwidth (GB/s) + + Raises: + ValueError: If device is not a CUDA device + """ + # Convert to torch.device object if string is passed + if isinstance(device, str): + device = torch.device(device) + + # Check if it's a CUDA device + if device.type != "cuda": + raise ValueError(f"Device must be a CUDA device, got {device}") + + # Get device index + device_index = device.index if device.index is not None else 0 + + # Use pynvml to get bandwidth + pynvml.nvmlInit() + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle) + mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM) + + # Calculate theoretical peak bandwidth (GB/s) + bandwidth = (mem_clock * bus_width * 2) / 8 / 1000 + + return bandwidth + finally: + pynvml.nvmlShutdown() + + def _check_cached_qkv_data_type( q: torch.Tensor, k: torch.Tensor, dtype_q: torch.dtype, dtype_kv: torch.dtype ) -> None: diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 7547a06090..7901b71e22 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -20,14 +20,40 @@ #include #include #include +#include #include "layout.cuh" #include "math.cuh" +#include "page.cuh" #include "utils.cuh" #include "vec_dtypes.cuh" namespace flashinfer { +struct RopeQuantizeAppendPagedKVCacheParams { + uint32_t nnz; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t rope_dim; + uint32_t no_rope_dim; + size_t q_rope_in_stride_n; + size_t q_rope_in_stride_h; + size_t q_nope_in_stride_n; + size_t q_nope_in_stride_h; + size_t q_rope_out_stride_n; + size_t q_rope_out_stride_h; + size_t q_nope_out_stride_n; + size_t q_nope_out_stride_h; + size_t k_rope_in_stride; + size_t 
k_rope_in_stride_h; + size_t k_nope_in_stride; + size_t k_nope_in_stride_h; + size_t v_in_stride; + size_t v_in_stride_h; + float quant_scale_q; + float quant_scale_kv; +}; + /*! * \brief An enumeration class that defines different modes for applying RoPE * (Rotary Positional Embeddings). @@ -384,7 +410,7 @@ __global__ void RopeQuantizeKernel( // 2. if not interleave // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)] - if ((tx * vec_size < rope_dim) and (by < k_rope_end)) { + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { int sin_offset = rope_dim / 2; int vec_idx; if constexpr (interleave) { @@ -717,34 +743,237 @@ __global__ void BatchQKApplyRotaryKernel( } } -#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ - if (interleave) { \ - const bool INTERLEAVE = true; \ - __VA_ARGS__ \ - } else { \ - const bool INTERLEAVE = false; \ - __VA_ARGS__ \ - } +/*! + * \brief Unified CUDA kernel to apply RoPE, quantize to FP8, and append to paged cache. + * + * Templated on CacheT to support both GQA/MHA (paged_kv_t) and MLA (paged_kv_mla_t). + * Cache-only behaviors are selected with constexpr on the CacheT. + */ +template +__global__ void RopeQuantizeAppendPagedKVCacheKernel( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, + QuantType* q_rope_out, QuantType* q_nope_out, CacheT paged_kv_like, + IdType* __restrict__ batch_indices, IdType* __restrict__ positions, + float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, + const RopeQuantizeAppendPagedKVCacheParams params) { +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); +#endif + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + uint32_t by = blockIdx.y; + uint32_t bdy = blockDim.y; + + // Local aliases for params for readability + const uint32_t nnz = params.nnz; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t rope_dim = params.rope_dim; + const uint32_t no_rope_dim = params.no_rope_dim; + const size_t q_rope_in_stride_n = params.q_rope_in_stride_n; + const size_t q_rope_in_stride_h = params.q_rope_in_stride_h; + const size_t q_nope_in_stride_n = params.q_nope_in_stride_n; + const size_t q_nope_in_stride_h = params.q_nope_in_stride_h; + const size_t q_rope_out_stride_n = params.q_rope_out_stride_n; + const size_t q_rope_out_stride_h = params.q_rope_out_stride_h; + const size_t q_nope_out_stride_n = params.q_nope_out_stride_n; + const size_t q_nope_out_stride_h = params.q_nope_out_stride_h; + const size_t k_rope_in_stride = params.k_rope_in_stride; + const size_t k_rope_in_stride_h = params.k_rope_in_stride_h; + const size_t k_nope_in_stride = params.k_nope_in_stride; + const size_t k_nope_in_stride_h = params.k_nope_in_stride_h; + const size_t v_in_stride = params.v_in_stride; + const size_t v_in_stride_h = params.v_in_stride_h; + const float quant_scale_q = params.quant_scale_q; + const float quant_scale_kv = params.quant_scale_kv; + + // Calculate flexible boundaries for block allocation + uint32_t rope_chunk_size = rope_dim; + uint32_t rope_chunks = (rope_dim + rope_chunk_size - 1) / rope_chunk_size; + uint32_t no_rope_chunks = (no_rope_dim + rope_chunk_size - 1) / rope_chunk_size; + + uint32_t q_rope_end = num_qo_heads * rope_chunks; + // For MLA, num_kv_heads is effectively 1 + uint32_t k_rope_end = q_rope_end + num_kv_heads * 
rope_chunks; + uint32_t k_nope_end = k_rope_end + num_kv_heads * no_rope_chunks; + + // Deduce MLA vs GQA/MHA from CacheT + constexpr bool IS_MLA = std::is_same>::value; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + + // Compute page location for this token + uint32_t page_iter, entry_idx; + paged_kv_like.page_size.divmod( + paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx], + page_iter, entry_idx); + + const int half_rope_dim = rope_dim / 2; + // Load cos/sin for RoPE processing blocks only + if ((tx * vec_size < rope_dim) && (by < k_rope_end)) { + int sin_offset = rope_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rope_dim; + } + cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx)); + } + + if (by < q_rope_end) { + // ============ Q RoPE processing ============ + uint32_t q_head_idx = by / rope_chunks; + uint32_t rope_chunk_idx = by % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* q_rope_in_ptr = + q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n, + q_rope_in_stride_h); + QuantType* q_rope_out_ptr = + q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n, + q_rope_out_stride_h); + + vec_t q_rope_vec; + if constexpr (interleave) { + q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + q_rope_in_ptr, cos, sin, rope_dim); + } else { + q_rope_vec = vec_apply_llama_rope_cos_sin(q_rope_in_ptr, cos, sin, rope_dim); + } +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + q_rope_vec[i] = q_rope_vec[i] * quant_scale_q; + } + q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size); + + } else if (by < k_rope_end) { + // ============ K RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - q_rope_end) / rope_chunks; + uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks; + uint32_t elem_offset = rope_chunk_idx * rope_chunk_size; + + DType* k_rope_in_ptr; + if constexpr (IS_MLA) { + // MLA: 2D K + k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset; + } else { + // GQA/MHA: 3D K + k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_rope_in_stride, k_rope_in_stride_h); + } + + vec_t k_rope_vec; + if constexpr (interleave) { + k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half( + k_rope_in_ptr, cos, sin, rope_dim); + } else { + k_rope_vec = vec_apply_llama_rope_cos_sin(k_rope_in_ptr, cos, sin, rope_dim); + } +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv; + } + + if constexpr (IS_MLA) { + QuantType* kpe_ptr = + paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_rope_vec.cast_store(kpe_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size); + k_rope_vec.cast_store(k_ptr); + } + + } else if (by < k_nope_end) { + // ============ K Non-RoPE processing & Cache Append ============ + uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* k_nope_in_ptr; + if constexpr (IS_MLA) { + k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + 
elem_offset; + } else { + k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset, + k_nope_in_stride, k_nope_in_stride_h); + } + + vec_t k_nope_vec; + k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv; + } -#define DISPATCH_ROPE_DIM(rope_dim, vec_size, ...) \ - if (rope_dim == 16) { \ - constexpr uint32_t bdx = 16 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 32) { \ - constexpr uint32_t bdx = 32 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 64) { \ - constexpr uint32_t bdx = 64 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 128) { \ - constexpr uint32_t bdx = 128 / vec_size; \ - __VA_ARGS__ \ - } else if (rope_dim == 256) { \ - constexpr uint32_t bdx = 256 / vec_size; \ - __VA_ARGS__ \ - } else { \ - FLASHINFER_ERROR("Unsupported rope_dim. Supported values: 16, 32, 64, 128, 256"); \ + if constexpr (IS_MLA) { + QuantType* ckv_ptr = + paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size); + k_nope_vec.cast_store(ckv_ptr); + } else { + QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, + rope_dim + elem_offset + tx * vec_size); + k_nope_vec.cast_store(k_ptr); + } + + } else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) { + // ============ V processing & Cache Append (GQA/MHA only) ============ + if constexpr (!IS_MLA) { + uint32_t kv_head_idx = by - k_nope_end; + DType* v_in_ptr = + v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h); + // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size + uint32_t head_dim_total = rope_dim + no_rope_dim; + uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size; +#pragma unroll 1 + for (uint32_t j = 0; j < v_chunks; ++j) { + uint32_t v_elem_offset = j * rope_chunk_size; + if (v_elem_offset + tx * vec_size < head_dim_total) { + vec_t v_vec; + v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + v_vec[i] = v_vec[i] * quant_scale_kv; + } + QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, + v_elem_offset + tx * vec_size); + v_vec.cast_store(v_ptr); + } + } + } + + } else { + // ============ Q Non-RoPE processing ============ + // MLA has no V section, so Q-nope starts immediately after K-nope. + // GQA/MHA has a V section of length num_kv_heads blocks. + uint32_t q_nope_start = k_nope_end + (IS_MLA ? 
0u : num_kv_heads); + uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks; + uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks; + uint32_t elem_offset = nope_chunk_idx * rope_chunk_size; + + DType* q_nope_in_ptr = + q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n, + q_nope_in_stride_h); + QuantType* q_nope_out_ptr = + q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n, + q_nope_out_stride_h); + + vec_t q_nope_vec; + q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + q_nope_vec[i] = q_nope_vec[i] * quant_scale_q; + } + q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size); + } } +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} template cudaError_t RopeQuantize( @@ -763,11 +992,11 @@ cudaError_t RopeQuantize( FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - constexpr uint32_t vec_size = 32 / sizeof(DType); - // Use nested macros for runtime->compile-time dispatch for required constexpr values - DISPATCH_ROPE_DIM(rope_dim, vec_size, { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + constexpr uint32_t bdx = ROPE_DIM / vec_size; uint32_t num_threads = 128U; uint32_t bdy = num_threads / bdx; uint32_t nblks_x = (nnz + bdy - 1) / bdy; @@ -838,6 +1067,185 @@ cudaError_t RopeQuantize( return cudaSuccess; } +/*! + * \brief Host function to apply RoPE, quantize to FP8, and append K/V to paged cache (GQA/MHA) + */ +template +cudaError_t RopeQuantizeAppendPagedKVCache( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, DType* v_in, + QuantType* q_rope_out, QuantType* q_nope_out, paged_kv_t paged_kv, + IdType* batch_indices, IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rope_dim, uint32_t no_rope_dim, + size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, size_t q_nope_in_stride_n, + size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, size_t q_rope_out_stride_h, + size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, size_t k_rope_in_stride, + size_t k_rope_in_stride_h, size_t k_nope_in_stride, size_t k_nope_in_stride_h, + size_t v_in_stride, size_t v_in_stride_h, float quant_scale_q, float quant_scale_kv, + bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + constexpr uint32_t bdx = ROPE_DIM / vec_size; + uint32_t num_threads = 128U; + uint32_t bdy = num_threads / bdx; + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + // GQA/MHA: Q rope + K rope + K nope + V + Q nope + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_kv_heads + + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 
1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = RopeQuantizeAppendPagedKVCacheKernel>; + RopeQuantizeAppendPagedKVCacheParams params; + params.nnz = nnz; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.rope_dim = rope_dim; + params.no_rope_dim = no_rope_dim; + params.q_rope_in_stride_n = q_rope_in_stride_n; + params.q_rope_in_stride_h = q_rope_in_stride_h; + params.q_nope_in_stride_n = q_nope_in_stride_n; + params.q_nope_in_stride_h = q_nope_in_stride_h; + params.q_rope_out_stride_n = q_rope_out_stride_n; + params.q_rope_out_stride_h = q_rope_out_stride_h; + params.q_nope_out_stride_n = q_nope_out_stride_n; + params.q_nope_out_stride_h = q_nope_out_stride_h; + params.k_rope_in_stride = k_rope_in_stride; + params.k_rope_in_stride_h = k_rope_in_stride_h; + params.k_nope_in_stride = k_nope_in_stride; + params.k_nope_in_stride_h = k_nope_in_stride_h; + params.v_in_stride = v_in_stride; + params.v_in_stride_h = v_in_stride_h; + params.quant_scale_q = quant_scale_q; + params.quant_scale_kv = quant_scale_kv; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // params + params)); + }); + }); + + return cudaSuccess; +} + +/*! + * \brief Host function to apply RoPE, quantize to FP8, and append to MLA paged cache + */ +template +cudaError_t RopeQuantizeAppendPagedMLACache( + DType* q_rope_in, DType* k_rope_in, DType* q_nope_in, DType* k_nope_in, QuantType* q_rope_out, + QuantType* q_nope_out, paged_kv_mla_t paged_kv_mla, IdType* batch_indices, + IdType* positions, float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, + uint32_t rope_dim, uint32_t no_rope_dim, size_t q_rope_in_stride_n, size_t q_rope_in_stride_h, + size_t q_nope_in_stride_n, size_t q_nope_in_stride_h, size_t q_rope_out_stride_n, + size_t q_rope_out_stride_h, size_t q_nope_out_stride_n, size_t q_nope_out_stride_h, + size_t k_rope_in_stride, size_t k_nope_in_stride, float quant_scale_q, float quant_scale_kv, + bool interleave, bool enable_pdl = false, cudaStream_t stream = nullptr) { + DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, { + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + constexpr uint32_t vec_size = 32 / sizeof(DType); + constexpr uint32_t bdx = ROPE_DIM / vec_size; + uint32_t num_threads = 128U; + uint32_t bdy = num_threads / bdx; + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + uint32_t rope_chunks = 1; + uint32_t no_rope_chunks = (no_rope_dim + rope_dim - 1) / rope_dim; + + // MLA: Q rope + K rope + K nope + Q nope (no V) + constexpr uint32_t num_kv_heads = 1; + uint32_t total_blocks_y = num_qo_heads * rope_chunks + num_kv_heads * rope_chunks + + num_kv_heads * no_rope_chunks + num_qo_heads * no_rope_chunks; + + dim3 nblks(nblks_x, total_blocks_y); + dim3 nthrs(bdx, bdy); + + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 
1 : 0; + cudaLaunchConfig_t config; + config.gridDim = nblks; + config.blockDim = nthrs; + config.stream = stream; + config.dynamicSmemBytes = 0; + config.attrs = attribute; + config.numAttrs = 1; + + auto kernel = + RopeQuantizeAppendPagedKVCacheKernel>; + // For MLA: pass v_in as nullptr, num_kv_heads=1, duplicate 2D K strides for head strides, and + // 0 V strides + DType* v_in_nullptr = nullptr; + uint32_t num_kv_heads_1 = 1; + size_t k_rope_in_stride_h_dup = k_rope_in_stride; + size_t k_nope_in_stride_h_dup = k_nope_in_stride; + size_t v_in_stride_zero = 0, v_in_stride_h_zero = 0; + RopeQuantizeAppendPagedKVCacheParams params; + params.nnz = nnz; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = 1u; + params.rope_dim = rope_dim; + params.no_rope_dim = no_rope_dim; + params.q_rope_in_stride_n = q_rope_in_stride_n; + params.q_rope_in_stride_h = q_rope_in_stride_h; + params.q_nope_in_stride_n = q_nope_in_stride_n; + params.q_nope_in_stride_h = q_nope_in_stride_h; + params.q_rope_out_stride_n = q_rope_out_stride_n; + params.q_rope_out_stride_h = q_rope_out_stride_h; + params.q_nope_out_stride_n = q_nope_out_stride_n; + params.q_nope_out_stride_h = q_nope_out_stride_h; + params.k_rope_in_stride = k_rope_in_stride; + params.k_rope_in_stride_h = k_rope_in_stride_h_dup; + params.k_nope_in_stride = k_nope_in_stride; + params.k_nope_in_stride_h = k_nope_in_stride_h_dup; + params.v_in_stride = 0; + params.v_in_stride_h = 0; + params.quant_scale_q = quant_scale_q; + params.quant_scale_kv = quant_scale_kv; + + FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, + // inputs + q_rope_in, k_rope_in, q_nope_in, k_nope_in, + v_in_nullptr, + // q outputs + q_rope_out, q_nope_out, + // cache + indices + paged_kv_mla, batch_indices, positions, + // rope tables + cos_sin_cache, pos_ids, + // params + params)); + }); + }); + + return cudaSuccess; +} + template cudaError_t BatchQKApplyRotaryPosIdsCosSinCache( DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_sin_cache, IdType* pos_ids, diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 5b26d7beaf..0471bd1081 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -201,6 +201,52 @@ } \ } +// convert interleave to compile-time constant +#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ + if (interleave) { \ + constexpr bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_ROPE_DIM(rope_dim, ROPE_DIM, ...) \ + switch (rope_dim) { \ + case 16: { \ + constexpr uint32_t ROPE_DIM = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 32: { \ + constexpr uint32_t ROPE_DIM = 32; \ + __VA_ARGS__ \ + break; \ + } \ + case 64: { \ + constexpr uint32_t ROPE_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr uint32_t ROPE_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr uint32_t ROPE_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported ROPE_DIM: " << rope_dim; \ + err_msg << ". Supported values: 16, 32, 64, 128, 256"; \ + err_msg << " in DISPATCH_ROPE_DIM"; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + #define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) 
\ switch (pos_encoding_mode) { \ case PosEncodingMode::kNone: { \ diff --git a/tests/attention/test_rope.py b/tests/attention/test_rope.py index da59223a4f..8e694088e5 100644 --- a/tests/attention/test_rope.py +++ b/tests/attention/test_rope.py @@ -394,6 +394,10 @@ def test_generalized_rope_quantize( ): """Test generalized rope + quantization for MLA, GQA, and MHA architectures.""" device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) total_dim = rope_dim + no_rope_dim # Create input tensors based on attention type @@ -481,6 +485,893 @@ def test_generalized_rope_quantize( ) +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + # MLA: Multiple Q heads, single shared K/V head + ("mla", 128, 1, 64, 512), + ("mla", 64, 1, 128, 256), + ("mla", 128, 1, 64, 128), # Explicit DeepSeek R1 MLA config case + ("mla", 32, 1, 32, 96), + # GQA: Multiple Q heads, fewer K/V heads (grouped) + ("gqa", 32, 8, 64, 64), + ("gqa", 64, 16, 128, 128), + ("gqa", 24, 6, 32, 96), + ("gqa", 32, 8, 128, 0), # Llama3 8B standard config + ("gqa", 64, 8, 128, 0), # Llama3 70B standard config + ("gqa", 64, 8, 64, 0), # (plausible) GPT-OSS config + # MHA: Equal Q and K/V heads + ("mha", 32, 32, 64, 64), + ("mha", 16, 16, 128, 128), + ("mha", 8, 8, 32, 96), + ], +) +@pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) +@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16, 32]) +def test_generalized_rope_quantize_append_kv_cache( + attention_type, + num_qo_heads, + num_kv_heads, + rope_dim, + no_rope_dim, + num_tokens, + input_dtype, + quant_dtype, + enable_pdl, + kv_layout, + page_size, +): + device = "cuda:0" + # Fixed seed for reproducibility + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + head_dim = rope_dim + no_rope_dim + batch_size = 4 + + # Build inputs following the same pattern used elsewhere + if attention_type == "mla": + # Q: (N, Hq, *), K: 2D (N, *) + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + ) + k_rope = torch.randn(num_tokens, rope_dim, dtype=input_dtype, device=device) + k_nope = ( + None + if no_rope_dim == 0 + else torch.randn(num_tokens, no_rope_dim, dtype=input_dtype, device=device) + ) + v = None + else: + # GQA/MHA: K/V are 3D + q_rope = torch.randn( + num_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_qo_heads, no_rope_dim, dtype=input_dtype, device=device + ) + ) + k_rope = torch.randn( + num_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope = ( + None + if no_rope_dim == 0 + else torch.randn( + num_tokens, num_kv_heads, no_rope_dim, dtype=input_dtype, device=device + ) + ) + v = torch.randn( + num_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + # Cos/sin and positions + max_seq_len = 4096 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + 
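+ # rope_ref.cos_sin_cache has shape (max_seq_len, rope_dim): the first half of the
+ # last dim holds cos values and the second half holds sin, which is the layout the
+ # fused kernel expects for cos_sin_cache.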
pos_ids = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Build paged metadata + kv_append_length = torch.tensor( + [num_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length, dim=0), + ] + ) + num_pages_per_req = torch.tensor( + [(num_tokens + page_size - 1) // page_size] + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indptr = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(num_pages_per_req, dim=0), + ] + ) + kv_page_indices = torch.arange( + kv_page_indptr[-1].item(), dtype=torch.int32, device=device + ) + kv_last_page_len = torch.tensor( + [num_tokens % page_size if num_tokens % page_size != 0 else page_size] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + # Allocate caches sized by required pages + max_pages = kv_page_indptr[-1].item() + + # Get batch_indices and positions + seq_lens = flashinfer.get_seq_lens(kv_page_indptr, kv_last_page_len, page_size) + batch_indices, positions = flashinfer.get_batch_indices_positions( + kv_append_indptr, seq_lens, num_tokens + ) + + # Fused call + cache allocation + if attention_type == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, dtype=quant_dtype, device=device + ) + q_rope_out_fused, q_nope_out_fused = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + None, + rope_ref.cos_sin_cache, + pos_ids, + (ckv_cache, kpe_cache), + kv_page_indices, + kv_page_indptr, + batch_indices, + positions, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + else: + # Allocate cache based on layout + if kv_layout == "NHD": + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + q_rope_out_fused, q_nope_out_fused = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope, + k_rope, + q_nope, + k_nope, + v, + rope_ref.cos_sin_cache, + pos_ids, + (k_cache, v_cache), + kv_page_indices, + kv_page_indptr, + batch_indices, + positions, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + # Compute reference output (handle None for no_rope_dim == 0) + q_in = q_rope if q_nope is None else torch.cat([q_rope, q_nope], dim=-1) + k_in = k_rope if k_nope is None else torch.cat([k_rope, k_nope], dim=-1) + q_out_f16_ref, k_out_f16_ref = rope_ref.forward_native(pos_ids, q_in, k_in) + q_out_f8_ref, k_out_f8_ref = map( + lambda x: x.to(quant_dtype), + (q_out_f16_ref, k_out_f16_ref), + ) + + # Fused vs Pytorch reference Q checks + torch.testing.assert_close( + q_out_f8_ref[..., :rope_dim].float(), + q_rope_out_fused.float(), + rtol=2e-1, + atol=1e-2, + ) + torch.testing.assert_close( + q_out_f8_ref[..., 
rope_dim:].float(), + q_nope_out_fused.float(), + rtol=2e-1, + atol=1e-2, + ) + + # expect 1-ULP differences between FP8 device rounding and PyTorch .to(fp8) + if quant_dtype == torch.float8_e4m3fn: + rtol_val, atol_val = 0.25, 0.5 + else: # quant_dtype == torch.float8_e5m2: + rtol_val, atol_val = 0.25, 1.0 + + # if MLA: check ckv_cache, kpe_cache + if attention_type == "mla": + # Split K reference + k_rope_ref = k_out_f8_ref[..., :rope_dim] + k_nope_ref = k_out_f8_ref[..., rope_dim:] + + ckv_ref = torch.zeros_like(ckv_cache) + kpe_ref = torch.zeros_like(kpe_cache) + + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + ckv_ref[page_idx, entry_idx, :] = k_nope_ref[i] + kpe_ref[page_idx, entry_idx, :] = k_rope_ref[i] + + torch.testing.assert_close( + ckv_cache.float(), ckv_ref.float(), rtol=rtol_val, atol=atol_val + ) + torch.testing.assert_close( + kpe_cache.float(), kpe_ref.float(), rtol=rtol_val, atol=atol_val + ) + + # if GQA/MHA: check k_cache, v_cache + if attention_type == "gqa" or attention_type == "mha": + # K reference + k_ref = torch.zeros_like(k_cache) + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + if kv_layout == "NHD": + k_ref[page_idx, entry_idx, :, :] = k_out_f8_ref[i] # [Hkv, head_dim] + else: # HND + k_ref[page_idx, :, entry_idx, :] = k_out_f8_ref[i] # [Hkv, head_dim] + + torch.testing.assert_close( + k_cache.float(), k_ref.float(), rtol=rtol_val, atol=atol_val + ) + + # V reference (no RoPE on V; same quant scale as KV) + quant_scale_kv = 1.0 # match fused call + v_ref_tokens = (v * quant_scale_kv).to(quant_dtype) + v_ref = torch.zeros_like(v_cache) + for i in range(num_tokens): + b = batch_indices[i].item() + pos = positions[i].item() + page_iter = (kv_page_indptr[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices[page_iter].item() + if kv_layout == "NHD": + v_ref[page_idx, entry_idx, :, :] = v_ref_tokens[i] + else: # HND + v_ref[page_idx, :, entry_idx, :] = v_ref_tokens[i] + + torch.testing.assert_close( + v_cache.float(), v_ref.float(), rtol=rtol_val, atol=atol_val + ) + + +@pytest.mark.parametrize( + "attention_type,num_qo_heads,num_kv_heads,rope_dim,no_rope_dim", + [ + # MLA: Multiple Q heads, single shared K/V head + ("mla", 128, 1, 64, 512), + ("mla", 32, 1, 32, 96), + # GQA: Multiple Q heads, fewer K/V heads (grouped) + ("gqa", 32, 8, 64, 64), + ("gqa", 32, 8, 128, 0), # Llama3 8B standard config + # MHA: Equal Q and K/V heads + ("mha", 32, 32, 64, 64), + ("mha", 16, 16, 128, 128), + ], +) +@pytest.mark.parametrize("num_existing_tokens", [10, 50]) +@pytest.mark.parametrize("num_new_tokens", [1, 8]) +@pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("enable_pdl", [True, False]) +@pytest.mark.parametrize("kv_layout", ["NHD", "HND"]) +@pytest.mark.parametrize("page_size", [16, 32]) +def test_rope_quantize_fp8_append_paged_kv_cache_decode( + attention_type, + num_qo_heads, + num_kv_heads, + 
rope_dim, + no_rope_dim, + num_existing_tokens, + num_new_tokens, + input_dtype, + quant_dtype, + enable_pdl, + kv_layout, + page_size, +): + """Test append to non-empty cache (decode/continuation scenario).""" + device = "cuda:0" + torch.manual_seed(42) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(42) + + head_dim = rope_dim + no_rope_dim + batch_size = 2 + + # Step 1: Pre-populate cache with existing tokens + if attention_type == "mla": + q_rope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_existing = torch.randn( + num_existing_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + ) + v_existing = None + else: + q_rope_existing = torch.randn( + num_existing_tokens, + num_qo_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + q_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + rope_dim, + dtype=input_dtype, + device=device, + ) + k_nope_existing = ( + None + if no_rope_dim == 0 + else torch.randn( + num_existing_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + v_existing = torch.randn( + num_existing_tokens, + num_kv_heads, + head_dim, + dtype=input_dtype, + device=device, + ) + + # Create RoPE reference + max_seq_len = 4096 + rope_ref = FlashInferRotaryEmbedding( + head_dim, rope_dim, max_seq_len, 10000, False, input_dtype, device + ) + pos_ids_existing = torch.arange( + num_existing_tokens, device=device, dtype=torch.int32 + ) + + # Build metadata for existing tokens (single request for simplicity) + kv_append_length_existing = torch.tensor( + [num_existing_tokens] + [0] * (batch_size - 1), dtype=torch.int32, device=device + ) + kv_append_indptr_existing = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(kv_append_length_existing, dim=0), + ] + ) + num_pages_existing = (num_existing_tokens + page_size - 1) // page_size + kv_page_indptr_existing = torch.tensor( + [0, num_pages_existing] + [num_pages_existing] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indices_existing = torch.arange( + num_pages_existing, dtype=torch.int32, device=device + ) + kv_last_page_len_existing = torch.tensor( + [ + num_existing_tokens % page_size + if num_existing_tokens % page_size != 0 + else page_size + ] + + [0] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + seq_lens_existing = flashinfer.get_seq_lens( + kv_page_indptr_existing, kv_last_page_len_existing, page_size + ) + batch_indices_existing, positions_existing = flashinfer.get_batch_indices_positions( + kv_append_indptr_existing, seq_lens_existing, num_existing_tokens + ) + + # Allocate cache sized for existing + new tokens + total_tokens = num_existing_tokens + num_new_tokens + max_pages = (total_tokens + page_size - 1) // page_size + + if attention_type == "mla": + ckv_cache = torch.zeros( + max_pages, page_size, no_rope_dim, dtype=quant_dtype, device=device + ) + kpe_cache = torch.zeros( + max_pages, page_size, rope_dim, 
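+            # MLA cache semantics exercised below: ckv_cache receives the no-RoPE
+            # ("compressed KV") part of K, kpe_cache the RoPE'd positional part;
+            # both are stored in FP8 (quant_dtype).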
dtype=quant_dtype, device=device + ) + # Pre-populate with existing tokens + _, _ = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_existing, + k_rope_existing, + q_nope_existing, + k_nope_existing, + None, + rope_ref.cos_sin_cache, + pos_ids_existing, + (ckv_cache, kpe_cache), + kv_page_indices_existing, + kv_page_indptr_existing, + batch_indices_existing, + positions_existing, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + else: + if kv_layout == "NHD": + k_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + page_size, + num_kv_heads, + head_dim, + dtype=quant_dtype, + device=device, + ) + else: # HND + k_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + v_cache = torch.zeros( + max_pages, + num_kv_heads, + page_size, + head_dim, + dtype=quant_dtype, + device=device, + ) + # Pre-populate with existing tokens + _, _ = flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_existing, + k_rope_existing, + q_nope_existing, + k_nope_existing, + v_existing, + rope_ref.cos_sin_cache, + pos_ids_existing, + (k_cache, v_cache), + kv_page_indices_existing, + kv_page_indptr_existing, + batch_indices_existing, + positions_existing, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + + # Step 2: Append new tokens to the pre-populated cache + if attention_type == "mla": + q_rope_new = torch.randn( + num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_new = torch.randn( + num_new_tokens, rope_dim, dtype=input_dtype, device=device + ) + k_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, no_rope_dim, dtype=input_dtype, device=device + ) + ) + v_new = None + else: + q_rope_new = torch.randn( + num_new_tokens, num_qo_heads, rope_dim, dtype=input_dtype, device=device + ) + q_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_qo_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + k_rope_new = torch.randn( + num_new_tokens, num_kv_heads, rope_dim, dtype=input_dtype, device=device + ) + k_nope_new = ( + None + if no_rope_dim == 0 + else torch.randn( + num_new_tokens, + num_kv_heads, + no_rope_dim, + dtype=input_dtype, + device=device, + ) + ) + v_new = torch.randn( + num_new_tokens, num_kv_heads, head_dim, dtype=input_dtype, device=device + ) + + pos_ids_new = torch.arange( + num_existing_tokens, + num_existing_tokens + num_new_tokens, + device=device, + dtype=torch.int32, + ) + + # Build metadata for new tokens (continue appending to first request) + num_pages_new_needed = (total_tokens + page_size - 1) // page_size + kv_page_indptr_new = torch.tensor( + [0, num_pages_new_needed] + [num_pages_new_needed] * (batch_size - 1), + dtype=torch.int32, + device=device, + ) + kv_page_indices_new = torch.arange( + num_pages_new_needed, dtype=torch.int32, device=device + ) + # For continuation, positions start at num_existing_tokens + batch_indices_new = torch.zeros(num_new_tokens, device=device, dtype=torch.int32) + positions_new = 
torch.arange( + num_existing_tokens, + num_existing_tokens + num_new_tokens, + device=device, + dtype=torch.int32, + ) + + # Snapshot existing cache for later comparison + if attention_type == "mla": + ckv_cache_before = ckv_cache.clone() + kpe_cache_before = kpe_cache.clone() + else: + k_cache_before = k_cache.clone() + v_cache_before = v_cache.clone() + + # Append new tokens + if attention_type == "mla": + q_rope_out_new, q_nope_out_new = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_new, + k_rope_new, + q_nope_new, + k_nope_new, + None, + rope_ref.cos_sin_cache, + pos_ids_new, + (ckv_cache, kpe_cache), + kv_page_indices_new, + kv_page_indptr_new, + batch_indices_new, + positions_new, + page_size=page_size, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + else: + q_rope_out_new, q_nope_out_new = ( + flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache( + q_rope_new, + k_rope_new, + q_nope_new, + k_nope_new, + v_new, + rope_ref.cos_sin_cache, + pos_ids_new, + (k_cache, v_cache), + kv_page_indices_new, + kv_page_indptr_new, + batch_indices_new, + positions_new, + page_size=page_size, + kv_layout=kv_layout, + quantize_dtype=quant_dtype, + quant_scale_q=1.0, + quant_scale_kv=1.0, + is_neox=False, + enable_pdl=enable_pdl, + ) + ) + + # Verify Q outputs for new tokens (handle None for no_rope_dim == 0) + q_in_new = ( + q_rope_new + if q_nope_new is None + else torch.cat([q_rope_new, q_nope_new], dim=-1) + ) + k_in_new = ( + k_rope_new + if k_nope_new is None + else torch.cat([k_rope_new, k_nope_new], dim=-1) + ) + q_out_f16_ref_new, k_out_f16_ref_new = rope_ref.forward_native( + pos_ids_new, q_in_new, k_in_new + ) + q_out_f8_ref_new = q_out_f16_ref_new.to(quant_dtype) + k_out_f8_ref_new = k_out_f16_ref_new.to(quant_dtype) + + torch.testing.assert_close( + q_out_f8_ref_new[..., :rope_dim].float(), + q_rope_out_new.float(), + rtol=2e-1, + atol=1e-2, + ) + torch.testing.assert_close( + q_out_f8_ref_new[..., rope_dim:].float(), + q_nope_out_new.float(), + rtol=2e-1, + atol=1e-2, + ) + + # FP8 tolerances + if quant_dtype == torch.float8_e4m3fn: + rtol_val, atol_val = 0.25, 0.5 + else: + rtol_val, atol_val = 0.25, 1.0 + + # Verify existing cache entries remain unchanged + if attention_type == "mla": + # Check that entries before num_existing_tokens are unchanged + for i in range(num_existing_tokens): + b = batch_indices_existing[i].item() + pos = positions_existing[i].item() + page_iter = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) // page_size + entry_idx = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) % page_size + page_idx = kv_page_indices_existing[page_iter].item() + torch.testing.assert_close( + ckv_cache[page_idx, entry_idx, :].float(), + ckv_cache_before[page_idx, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing CKV cache entry {i} was modified", + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_idx, :].float(), + kpe_cache_before[page_idx, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing KPE cache entry {i} was modified", + ) + else: + for i in range(num_existing_tokens): + b = batch_indices_existing[i].item() + pos = positions_existing[i].item() + page_iter = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) // page_size + entry_idx = ( + kv_page_indptr_existing[b].item() * page_size + pos + ) % page_size + page_idx = kv_page_indices_existing[page_iter].item() + if kv_layout == "NHD": + torch.testing.assert_close( + 
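# Existing entries must be bit-identical to the pre-append snapshot
+                    # (rtol=0, atol=0): appending new tokens must not modify them.
+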
k_cache[page_idx, entry_idx, :, :].float(), + k_cache_before[page_idx, entry_idx, :, :].float(), + rtol=0, + atol=0, + msg=f"Existing K cache entry {i} was modified", + ) + torch.testing.assert_close( + v_cache[page_idx, entry_idx, :, :].float(), + v_cache_before[page_idx, entry_idx, :, :].float(), + rtol=0, + atol=0, + msg=f"Existing V cache entry {i} was modified", + ) + else: # HND + torch.testing.assert_close( + k_cache[page_idx, :, entry_idx, :].float(), + k_cache_before[page_idx, :, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing K cache entry {i} was modified", + ) + torch.testing.assert_close( + v_cache[page_idx, :, entry_idx, :].float(), + v_cache_before[page_idx, :, entry_idx, :].float(), + rtol=0, + atol=0, + msg=f"Existing V cache entry {i} was modified", + ) + + # Verify new cache entries are correct + if attention_type == "mla": + k_rope_ref_new = k_out_f8_ref_new[..., :rope_dim] + k_nope_ref_new = k_out_f8_ref_new[..., rope_dim:] + + for i in range(num_new_tokens): + b = batch_indices_new[i].item() + pos = positions_new[i].item() + page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices_new[page_iter].item() + torch.testing.assert_close( + ckv_cache[page_idx, entry_idx, :].float(), + k_nope_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + kpe_cache[page_idx, entry_idx, :].float(), + k_rope_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + else: + quant_scale_kv = 1.0 + v_ref_tokens_new = (v_new * quant_scale_kv).to(quant_dtype) + + for i in range(num_new_tokens): + b = batch_indices_new[i].item() + pos = positions_new[i].item() + page_iter = (kv_page_indptr_new[b].item() * page_size + pos) // page_size + entry_idx = (kv_page_indptr_new[b].item() * page_size + pos) % page_size + page_idx = kv_page_indices_new[page_iter].item() + if kv_layout == "NHD": + torch.testing.assert_close( + k_cache[page_idx, entry_idx, :, :].float(), + k_out_f8_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + v_cache[page_idx, entry_idx, :, :].float(), + v_ref_tokens_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + else: # HND + torch.testing.assert_close( + k_cache[page_idx, :, entry_idx, :].float(), + k_out_f8_ref_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + torch.testing.assert_close( + v_cache[page_idx, :, entry_idx, :].float(), + v_ref_tokens_new[i].float(), + rtol=rtol_val, + atol=atol_val, + ) + + @pytest.mark.parametrize("num_tokens", [1, 19, 128, 199, 899, 2047]) @pytest.mark.parametrize("input_dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @@ -492,6 +1383,10 @@ def test_mla_rope_quantize( enable_pdl, ): device = "cuda:0" + # Fixed seed for reproducibility across tests + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) num_qo_heads = 128 q_in = torch.randn(num_tokens, num_qo_heads, 576, dtype=input_dtype, device=device) k_in = torch.randn(num_tokens, 576, dtype=input_dtype, device=device)