
Commit 876c386

upd
1 parent 7886b7d commit 876c386

File tree

6 files changed: +563 -26 lines changed
Lines changed: 15 additions & 1 deletion

@@ -1 +1,15 @@
-// TODO: Not implemented yet
+#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
+#include "batch_prefill_sm90_config.inc"
+
+namespace flashinfer {
+
+{% for same_scheduler_for_all_heads in ["true", "false"] %}
+template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
+    <{{ head_dim_qk }},
+     {{ mask_mode }},
+     /*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
+     /*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
+     {{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
+{% endfor %}
+
+};  // namespace flashinfer
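The file above is a Jinja template: at JIT/AOT build time flashinfer renders it into a concrete .cu translation unit, and the {% for %} loop emits one instantiation per value of same_scheduler_for_all_heads. The Python sketch below only illustrates that rendering step; the template path and the render values (head_dim_qk, mask_mode, use_sliding_window, variant_name) are assumptions for illustration, not taken from this commit.

import jinja2

# Hypothetical template path; the real generator supplies its own path and values.
template_src = open("batch_prefill_fp8_sm90_ragged_inst.jinja").read()
rendered = jinja2.Template(template_src).render(
    head_dim_qk=128,                     # assumed head dimension
    mask_mode="MaskMode::kCausal",       # assumed mask mode
    use_sliding_window="false",
    variant_name="DefaultFP8Attention",  # placeholder variant name
)
# `rendered` now holds both BatchFP8PrefillWithRaggedKVCacheDispatched
# instantiations (SAME_SCHEDULER_FOR_ALL_HEADS = true and false).
print(rendered)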

csrc/batch_prefill_fp8_sm90.cu

Lines changed: 99 additions & 1 deletion
@@ -29,6 +29,11 @@ template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
 cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, bool enable_pdl,
                                                       cudaStream_t stream);
 
+template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
+          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params>
+cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched(Params& params, bool enable_pdl,
+                                                       cudaStream_t stream);
+
 }  // namespace flashinfer
 
 using namespace flashinfer;

@@ -78,7 +83,94 @@ void BatchPrefillWithRaggedKVCacheSM90Run(ffi::TensorView float_workspace_buffer
                                           int64_t window_left,
                                           bool enable_pdl  // placeholder
                                               ADDITIONAL_FUNC_PARAMS) {
-  return;  // TODO: Implement this function
+  PrefillPlanSM90Info plan_info;
+  plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
+
+  if (maybe_lse.has_value()) {
+    const auto& lse = maybe_lse.value();
+    TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0));
+    TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1));
+  }
+
+  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
+  void* int_buffer_ptr = int_workspace_buffer.data_ptr();
+
+  int64_t head_dim_qk = q.size(2);
+  int64_t head_dim_vo = v.size(2);
+
+  QKVLayout kv_layout = static_cast<QKVLayout>(layout);
+
+  cudaSetDevice(float_workspace_buffer.device().device_id);
+  const cudaStream_t stream = get_stream(float_workspace_buffer.device());
+  const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
+  bool use_swa = window_left != -1;
+
+  DISPATCH_context(
+      DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
+      USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] {
+        RaggedParams params;
+
+        params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
+        params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
+        params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
+        params.o_ptr = static_cast<DTypeO*>(o.data_ptr());
+        params.lse_ptr = maybe_lse ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr;
+        params.q_stride_n = q.stride(0);
+        params.q_stride_h = q.stride(1);
+        params.o_stride_n = o.stride(0);
+        params.o_stride_h = o.stride(1);
+        if (kv_layout == QKVLayout::kNHD) {
+          params.k_stride_n = k.stride(0);
+          params.k_stride_h = k.stride(1);
+          params.v_stride_n = v.stride(0);
+          params.v_stride_h = v.stride(1);
+        } else {
+          params.k_stride_h = k.stride(0);
+          params.k_stride_n = k.stride(1);
+          params.v_stride_h = v.stride(0);
+          params.v_stride_n = v.stride(1);
+        }
+        params.nnz_qo = q.size(0);
+        params.nnz_kv = k.size(0);
+        params.num_qo_heads = q.size(1);
+        params.num_kv_heads = k.size(1);
+        params.group_size = params.num_qo_heads / params.num_kv_heads;
+        params.window_left = window_left;
+        params.causal = mask_mode_code == 1;
+        params.qo_tile_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset);
+        params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset);
+        params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
+        params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
+        params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
+        params.head_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
+        params.work_indptr =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
+        params.batch_indices =
+            GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);
+
+        ADDITIONAL_PARAMS_SETTER
+
+        // Not support various head_dim for now
+        static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same");
+        // Currently only support same quantization precision
+        static_assert(std::is_same_v<DTypeQ, DTypeKV>);
+
+        bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
+        DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
+          cudaError_t status =
+              BatchFP8PrefillWithRaggedKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW,
+                                                         SAME_SCHEDULER_FOR_ALL_HEADS,
+                                                         AttentionVariant>(params, enable_pdl,
+                                                                           stream);
+
+          TVM_FFI_ICHECK(status == cudaSuccess)
+              << "BatchPrefillWithRaggedKVCacheSM90Run failed with error: "
+              << cudaGetErrorString(status);
+          return true;
+        });
+      });
 }
 
 void BatchPrefillWithPagedKVCacheSM90Run(

@@ -136,12 +228,18 @@ void BatchPrefillWithPagedKVCacheSM90Run(
           params.k_stride_h = paged_k_cache.stride(2);
           params.v_stride_n = paged_v_cache.stride(1);
           params.v_stride_h = paged_v_cache.stride(2);
+          // For sparse paged KV cache, store the stride between pages
+          params.k_page_stride = paged_k_cache.stride(0);
+          params.v_page_stride = paged_v_cache.stride(0);
         } else {
          // (num_pages, num_heads, page_size, head_dim)
           params.k_stride_h = paged_k_cache.stride(1);
           params.k_stride_n = paged_k_cache.stride(2);
           params.v_stride_h = paged_v_cache.stride(1);
           params.v_stride_n = paged_v_cache.stride(2);
+          // For sparse paged KV cache, store the stride between pages
+          params.k_page_stride = paged_k_cache.stride(0);
+          params.v_page_stride = paged_v_cache.stride(0);
         }
         params.nnz_qo = q.size(0);
         params.num_qo_heads = q.size(1);
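A note on the new k_page_stride / v_page_stride fields in this hunk: they record stride(0) of the paged cache, i.e. the distance between consecutive pages, which the existing per-token (k_stride_n) and per-head (k_stride_h) strides do not capture; the sparse paged KV path needs it to jump between non-contiguous pages. The sketch below is not part of the commit (shapes are made up) and only shows how the stride fields line up with the two paged layouts.

import torch

num_pages, page_size, num_kv_heads, head_dim = 16, 64, 4, 128

# NHD paged layout: (num_pages, page_size, num_heads, head_dim)
k_nhd = torch.empty(num_pages, page_size, num_kv_heads, head_dim)
k_stride_n = k_nhd.stride(1)     # between consecutive tokens within a page
k_stride_h = k_nhd.stride(2)     # between heads
k_page_stride = k_nhd.stride(0)  # between pages (the new field)

# HND paged layout: (num_pages, num_heads, page_size, head_dim)
k_hnd = torch.empty(num_pages, num_kv_heads, page_size, head_dim)
k_stride_h = k_hnd.stride(1)
k_stride_n = k_hnd.stride(2)
k_page_stride = k_hnd.stride(0)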

flashinfer/prefill.py

Lines changed: 79 additions & 11 deletions
@@ -413,7 +413,13 @@ def ragged_run(
         rope_scale: float,
         rope_theta: float,
         token_pos_in_items_len: int,
+        scale_q: Optional[torch.Tensor] = None,
+        scale_k: Optional[torch.Tensor] = None,
+        scale_v: Optional[torch.Tensor] = None,
     ) -> None:
+        # Check if FP8 by presence of scale tensors
+        is_fp8 = scale_q is not None
+
         if backend == "fa2":
             ragged_run_func(
                 float_workspace_buffer,

@@ -439,10 +445,34 @@
                 logits_soft_cap,
                 sm_scale,
                 1.0 / rope_scale,  # rope_rcp_scale
-                1.0 / rope_theta,  # rope_rcp_theta
+                1.0 / rope_theta,  # rope_rcp_theta,
                 token_pos_in_items_len,
             )
+        elif is_fp8:
+            # FA3 FP8: scale_q, scale_k, scale_v, sm_scale
+            ragged_run_func(
+                float_workspace_buffer,
+                int_workspace_buffer,
+                plan_info_vec,
+                q,
+                k,
+                v,
+                qo_indptr,
+                kv_indptr,
+                o,
+                maybe_lse,
+                mask_mode,
+                layout,
+                window_left,
+                enable_pdl,
+                scale_q,
+                scale_k,
+                scale_v,
+                sm_scale,
+            )
         else:
+            # FA3 FP16: maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr,
+            # maybe_max_item_len_ptr, logits_soft_cap, sm_scale, token_pos_in_items_len
             ragged_run_func(
                 float_workspace_buffer,
                 int_workspace_buffer,

@@ -1533,6 +1563,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         non_blocking: bool = True,
         prefix_len_ptr: Optional[torch.Tensor] = None,
         token_pos_in_items_ptr: Optional[torch.Tensor] = None,

@@ -1617,6 +1648,9 @@ def plan(
             The data type of the query tensor, defaults torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]

@@ -1668,6 +1702,9 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
 
         if logits_soft_cap is None:
             logits_soft_cap = 0.0

@@ -1798,6 +1835,7 @@ def plan(
 
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
 
         if self._jit_module is not None:
             self._cached_module = self._jit_module

@@ -1815,7 +1853,7 @@
             get_module_args = (
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 paged_kv_indptr.dtype,
                 head_dim_qk,
                 head_dim_vo,

@@ -2052,12 +2090,15 @@ def run(
         )
 
         if out is None:
+            # Use cached output data type if available (for FP8 attention with FP16 output)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
             out = torch.empty(
-                q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
             )
         else:
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
             check_shape_dtype_device(
-                out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
+                out, q.shape[:-1] + v_cache.shape[-1:], out_dtype, q.device, "out"
             )
 
         # Convert NHD layout to HND for trtllm-gen backend

@@ -2126,6 +2167,14 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
             run_args += [
                 self._custom_mask_buf,
                 self._mask_indptr_buf,

@@ -2135,9 +2184,9 @@
                 self._max_item_len_ptr,
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 self._token_pos_in_items_len,

@@ -2466,6 +2515,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         non_blocking: bool = True,
         prefix_len_ptr: Optional[torch.Tensor] = None,
         token_pos_in_items_ptr: Optional[torch.Tensor] = None,

@@ -2540,6 +2590,9 @@ def plan(
             The data type of the query tensor, defaults to torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]

@@ -2580,6 +2633,9 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
         if head_dim_vo is None:
             head_dim_vo = head_dim_qk
         if fixed_split_size is None:

@@ -2652,6 +2708,7 @@ def plan(
 
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1]
 
         self._prefix_len_ptr = prefix_len_ptr

@@ -2675,7 +2732,7 @@
             get_module_args = (
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 kv_indptr.dtype,
                 head_dim_qk,
                 head_dim_vo,

@@ -2862,11 +2919,17 @@ def run(
        )
        if out is None:
            out = torch.empty(
-                q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
+                q.shape[:-1] + v.shape[-1:],
+                dtype=self._cached_o_data_type,
+                device=q.device,
            )
        else:
            check_shape_dtype_device(
-                out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out"
+                out,
+                q.shape[:-1] + v.shape[-1:],
+                self._cached_o_data_type,
+                q.device,
+                "out",
            )
        if self._backend == "cutlass":
            out, lse = fmha_varlen(

@@ -2884,7 +2947,9 @@
            )
            return (out, lse) if return_lse else out
 
-        if is_float8(q):
+        # Skip FP8->FP16 conversion for FA3 backend with FP8 support;
+        # the JIT module will handle FP8 natively
+        if is_float8(q) and self._backend != "fa3":
            logging.warning(
                "Our current prefill kernel implementation needs f16 input, the f8 inputs "
                " are casted to f16, which could result in performance degradation."

@@ -2933,6 +2998,9 @@
            rope_theta,
            self._token_pos_in_items_len,
        ]
+        # For FP8, append scale tensors
+        if is_float8(q):
+            run_args.extend(list(args))  # scale_q, scale_k, scale_v
 
        assert self._cached_module is not None, "cached module is not initialized"
        self._cached_module.ragged_run(*run_args)
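Taken together, the Python changes expose the FP8 ragged prefill path as follows: plan() now accepts o_data_type so FP8 inputs can produce FP16 output, and run() forwards scale_q/scale_k/scale_v through *args to the fa3 ragged kernel. The sketch below is a hedged usage example, not code from this commit; the wrapper and the plan()/run() signatures follow the public flashinfer API, while the quantization scheme, shapes, and workspace size are illustrative assumptions.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
qo_indptr = torch.tensor([0, 128, 256], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 512, 1024], dtype=torch.int32, device="cuda")

q = torch.randn(256, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(1024, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(1024, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

def to_fp8(x: torch.Tensor):
    # Simple per-tensor quantization; 448 is roughly the e4m3 maximum.
    scale = x.abs().amax().float() / 448.0
    return (x / scale).to(torch.float8_e4m3fn), scale.reshape(1)

q_fp8, scale_q = to_fp8(q)
k_fp8, scale_k = to_fp8(k)
v_fp8, scale_v = to_fp8(v)

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, backend="fa3")
wrapper.plan(
    qo_indptr,
    kv_indptr,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    causal=True,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.float16,  # new in this commit: FP16 output for FP8 inputs
)
# The scale tensors ride along as extra *args; run() extracts them when q is FP8.
out = wrapper.run(q_fp8, k_fp8, v_fp8, scale_q, scale_k, scale_v)

If o_data_type were left unset it would default to q_data_type (FP8 here), which is why the new docstring recommends torch.float16 for FP8 inputs.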
