Merged
10 changes: 10 additions & 0 deletions csrc/flashinfer_rope_binding.cu
@@ -45,9 +45,19 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
TensorView pos_ids, double quant_scale_q, double quant_scale_kv, bool interleave,
bool enable_pdl);

void rope_quantize_append_paged_kv_cache(
TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in,
TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache,
TensorView pos_ids, TensorView k_cache, TensorView v_cache, TensorView ckv_cache,
TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices,
TensorView positions, int64_t kv_layout_code, int64_t page_size, double quant_scale_q,
double quant_scale_kv, bool interleave, bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope, apply_rope);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope, apply_llama31_rope);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids, apply_rope_pos_ids);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_llama31_rope_pos_ids, apply_llama31_rope_pos_ids);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(apply_rope_pos_ids_cos_sin_cache, apply_rope_pos_ids_cos_sin_cache);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize, rope_quantize);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rope_quantize_append_paged_kv_cache,
rope_quantize_append_paged_kv_cache);
200 changes: 200 additions & 0 deletions csrc/rope.cu
@@ -420,3 +420,203 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope
});
});
}

/*!
 * TVM FFI binding for the fused RoPE + quantization + paged KV cache append kernel.
 *
 * Validates tensor shapes, dimensions, and data types, then dispatches to the templated
 * RopeQuantizeAppendPagedMLACache (MLA) or RopeQuantizeAppendPagedKVCache (GQA/MHA)
 * CUDA kernel implementation.
 */
void rope_quantize_append_paged_kv_cache(
TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope_in, TensorView k_nope_in,
TensorView v_in, TensorView q_rope_out, TensorView q_nope_out, TensorView cos_sin_cache,
TensorView pos_ids,
// Paged cache tensors
TensorView k_cache, TensorView v_cache, TensorView ckv_cache, TensorView kpe_cache,
TensorView kv_indices, TensorView kv_indptr, TensorView batch_indices, TensorView positions,
int64_t kv_layout_code, int64_t page_size, double quant_scale_q, double quant_scale_kv,
bool interleave, bool enable_pdl) {
// Validate inputs
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_in);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_rope_in);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_in);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_nope_in);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_rope_out);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_nope_out);
CHECK_INPUT(cos_sin_cache);
CHECK_INPUT(pos_ids);
CHECK_INPUT(kv_indices);
CHECK_INPUT(kv_indptr);
CHECK_INPUT(batch_indices);
CHECK_INPUT(positions);

// Extract dimensions
uint32_t rope_dim = q_rope_in.size(-1);
uint32_t no_rope_dim = q_nope_in.size(-1);
uint32_t nnz = q_rope_in.size(0);
uint32_t num_qo_heads = q_rope_in.size(1);

// Validate dimensions
TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim);
TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim);
TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim);
TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim);
TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype());
TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype());
TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype());

// Validate input/output dtypes
TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16)
<< "Input dtype must be float16 or bfloat16";
TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2)
<< "Output dtype must be float8_e4m3fn or float8_e5m2";

// Q tensors are always 3D
CHECK_DIM(3, q_rope_in);
CHECK_DIM(3, q_nope_in);
CHECK_DIM(3, q_rope_out);
CHECK_DIM(3, q_nope_out);

// Detect architecture: MLA (2D K) vs GQA/MHA (3D K)
bool is_mla = (k_rope_in.ndim() == 2);
uint32_t num_kv_heads;
uint32_t batch_size = kv_indptr.size(0) - 1;
QKVLayout kv_layout = QKVLayout(kv_layout_code);

if (is_mla) {
Collaborator:
Can we unify the two branches?

On the C++ side, we assume the K tensor is 3D. On the Python side, if we find that the K tensor is 2D, we unsqueeze its dimension 1. (A minimal sketch of this idea appears after the function below.)

// MLA: K tensors are 2D
CHECK_DIM(2, k_rope_in);
CHECK_DIM(2, k_nope_in);
num_kv_heads = 1;
// V can be empty for MLA
TVM_FFI_ICHECK(v_in.data_ptr() == nullptr || v_in.size(0) == 0)
<< "MLA should not have V input (or it should be empty)";
// Validate MLA cache tensors are provided
TVM_FFI_ICHECK(ckv_cache.data_ptr() != nullptr && kpe_cache.data_ptr() != nullptr)
<< "MLA requires ckv_cache and kpe_cache";
CHECK_DIM(3, ckv_cache); // (max_pages, page_size, ckv_dim)
CHECK_DIM(3, kpe_cache); // (max_pages, page_size, kpe_dim)
TVM_FFI_ICHECK_EQ(ckv_cache.size(2), no_rope_dim);
TVM_FFI_ICHECK_EQ(kpe_cache.size(2), rope_dim);
} else {
// GQA/MHA: K tensors are 3D
CHECK_DIM(3, k_rope_in);
CHECK_DIM(3, k_nope_in);
num_kv_heads = k_rope_in.size(1);
TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads);
// V is required for GQA/MHA
CHECK_DIM(3, v_in);
TVM_FFI_ICHECK_EQ(v_in.size(0), nnz);
TVM_FFI_ICHECK_EQ(v_in.size(1), num_kv_heads);
// Validate GQA/MHA cache tensors are provided
TVM_FFI_ICHECK(k_cache.data_ptr() != nullptr && v_cache.data_ptr() != nullptr)
<< "GQA/MHA requires k_cache and v_cache";
// Cache must be 4D
CHECK_DIM(4, k_cache);
CHECK_DIM(4, v_cache);
}

// Extract Q strides
const uint32_t q_rope_in_stride_n = q_rope_in.stride(0);
const uint32_t q_rope_in_stride_h = q_rope_in.stride(1);
const uint32_t q_nope_in_stride_n = q_nope_in.stride(0);
const uint32_t q_nope_in_stride_h = q_nope_in.stride(1);
const uint32_t q_rope_out_stride_n = q_rope_out.stride(0);
const uint32_t q_rope_out_stride_h = q_rope_out.stride(1);
const uint32_t q_nope_out_stride_n = q_nope_out.stride(0);
const uint32_t q_nope_out_stride_h = q_nope_out.stride(1);

// Extract K strides (architecture dependent)
uint32_t k_rope_in_stride, k_nope_in_stride;
uint32_t k_rope_in_stride_h, k_nope_in_stride_h;
uint32_t v_in_stride = 0, v_in_stride_h = 0;

if (is_mla) {
Collaborator:

ditto

// MLA: 2D K tensors
k_rope_in_stride = k_rope_in.stride(0);
k_nope_in_stride = k_nope_in.stride(0);
k_rope_in_stride_h = k_rope_in_stride; // Same as batch stride for 2D
k_nope_in_stride_h = k_nope_in_stride;
} else {
// GQA/MHA: 3D K tensors
k_rope_in_stride = k_rope_in.stride(0);
k_rope_in_stride_h = k_rope_in.stride(1);
k_nope_in_stride = k_nope_in.stride(0);
k_nope_in_stride_h = k_nope_in.stride(1);
v_in_stride = v_in.stride(0);
v_in_stride_h = v_in.stride(1);
}

cudaSetDevice(q_rope_in.device().device_id);
const cudaStream_t stream = get_stream(q_rope_in.device());

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] {
cudaError_t status;

if (is_mla) {
// MLA: Construct paged_kv_mla_t struct
auto ckv_strides = ckv_cache.strides();
auto kpe_strides = kpe_cache.strides();

paged_kv_mla_t<c_quant_type, int32_t> paged_kv_mla(
page_size, no_rope_dim, rope_dim, batch_size,
static_cast<c_quant_type*>(ckv_cache.data_ptr()), ckv_strides.data(),
static_cast<c_quant_type*>(kpe_cache.data_ptr()), kpe_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedMLACache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv_mla,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, rope_dim, no_rope_dim,
q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h,
q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h,
k_rope_in_stride, k_nope_in_stride, quant_scale_q, quant_scale_kv, interleave,
enable_pdl, stream);

} else {
// GQA/MHA: Construct paged_kv_t struct
auto k_strides = k_cache.strides();
auto v_strides = v_cache.strides();
uint32_t head_dim = rope_dim + no_rope_dim;

paged_kv_t<c_quant_type, int32_t> paged_kv(
num_kv_heads, page_size, head_dim, batch_size, kv_layout,
static_cast<c_quant_type*>(k_cache.data_ptr()),
static_cast<c_quant_type*>(v_cache.data_ptr()), k_strides.data(),
static_cast<int32_t*>(kv_indices.data_ptr()),
static_cast<int32_t*>(kv_indptr.data_ptr()),
nullptr // last_page_len not needed for this kernel
);

status = RopeQuantizeAppendPagedKVCache(
static_cast<c_type*>(q_rope_in.data_ptr()), static_cast<c_type*>(k_rope_in.data_ptr()),
static_cast<c_type*>(q_nope_in.data_ptr()), static_cast<c_type*>(k_nope_in.data_ptr()),
static_cast<c_type*>(v_in.data_ptr()),
static_cast<c_quant_type*>(q_rope_out.data_ptr()),
static_cast<c_quant_type*>(q_nope_out.data_ptr()), paged_kv,
static_cast<int32_t*>(batch_indices.data_ptr()),
static_cast<int32_t*>(positions.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim,
no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n,
q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n,
q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride,
k_nope_in_stride_h, v_in_stride, v_in_stride_h, quant_scale_q, quant_scale_kv,
interleave, enable_pdl, stream);
}

TVM_FFI_ICHECK(status == cudaSuccess)
<< "RopeQuantizeAppendPagedKVCache failed with error code " << cudaGetErrorString(status);
return true;
});
});
}
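
For reference, a minimal sketch of the Python-side normalization the reviewer suggests above, assuming a PyTorch wrapper sits in front of this binding; `_normalize_k_layout` is a hypothetical helper name and is not part of this PR:

```python
import torch


def _normalize_k_layout(k_rope: torch.Tensor, k_nope: torch.Tensor):
    """Hypothetical wrapper-side helper: ensure K tensors are 3D before the FFI call.

    MLA passes K as 2D (nnz, dim); unsqueezing a singleton head dimension gives
    (nnz, 1, dim), so the C++ binding would only have to validate the 3D layout.
    """
    if k_rope.dim() == 2:
        k_rope = k_rope.unsqueeze(1)
    if k_nope.dim() == 2:
        k_nope = k_nope.unsqueeze(1)
    return k_rope, k_nope
```

With that precondition, the two `is_mla` branches above could share the 3D shape checks and K-stride extraction, and MLA could be detected from the presence of `ckv_cache`/`kpe_cache` rather than from the rank of the K tensors.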