Commit a8d9e6a

Enable Hopper FA3 FP8 attention
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
1 parent 12b8ad0 commit a8d9e6a

File tree

5 files changed: +84 additions, −30 deletions

  flashinfer/decode.py
  flashinfer/jit/attention/modules.py
  flashinfer/prefill.py
  include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh
  include/flashinfer/attention/hopper/variants.cuh

flashinfer/decode.py

Lines changed: 59 additions & 20 deletions

@@ -55,6 +55,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -710,7 +711,7 @@ def __init__(
             self._jit_module = get_batch_prefill_jit_module(
                 jit_args[0],
                 gen_customize_batch_prefill_module(
-                    "fa2", *jit_args
+                    backend, *jit_args
                 ).build_and_load(),
             )
         else:
@@ -822,6 +823,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -869,6 +871,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
@@ -964,6 +969,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -973,6 +982,7 @@ def plan(

         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1012,7 +1022,7 @@ def plan(
             self._cached_module = get_trtllm_gen_decode_module(
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 indptr.dtype,
                 head_dim,
                 head_dim,
@@ -1027,11 +1037,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1041,7 +1060,7 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )

-            self._plan_info = self._cached_module.plan(
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1058,9 +1077,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1069,7 +1092,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1278,9 +1301,13 @@ def run(
            )

        if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
        else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

        if self._backend == "trtllm-gen":
            q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1308,6 +1335,14 @@ def run(
            if self._jit_module is not None:
                run_args.extend(list(args))
            else:
+                # Extract FP8 scale tensors from *args if q is FP8
+                fp8_scale_q = None
+                fp8_scale_k = None
+                fp8_scale_v = None
+                if is_float8(q) and len(args) >= 3:
+                    fp8_scale_q = args[0]
+                    fp8_scale_k = args[1]
+                    fp8_scale_v = args[2]
                run_args += [
                    None,  # packed_custom_mask
                    None,  # mask_indptr_buf
@@ -1317,9 +1352,9 @@ def run(
                    None,  # maybe_max_item_len_ptr
                    logits_soft_cap,
                    sm_scale,
-                    None,  # scale_q, not supported yet
-                    None,  # scale_k
-                    None,  # scale_v
+                    fp8_scale_q,
+                    fp8_scale_k,
+                    fp8_scale_v,
                    rope_scale,
                    rope_theta,
                    0,  # token_pos_in_items_len
@@ -1372,7 +1407,7 @@ def run(
        ]

        self._cached_module.run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
            # TODO(Zihao): fused into kernel
            if is_float8(out):
                out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2921,8 +2956,8 @@ def fast_decode_plan(
    kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

    try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
@@ -2939,9 +2974,13 @@ def fast_decode_plan(
            head_dim,
            False,  # causal
            window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
        )
    except Exception as e:
        raise RuntimeError(f"Error in standard plan: {e}") from e
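Taken together, the decode.py changes let the tensor-core decode path plan for FP8 query/KV with a separate output dtype and route per-head scale tensors into the FA3 kernels. The sketch below is illustrative only and not part of the commit: it assumes a Hopper GPU, an FP8-capable flashinfer build, the usual BatchDecodeWithPagedKVCacheWrapper plan() arguments, and dummy tensor contents; the scale tensors are passed positionally, which is how run() above picks them up as args[0:3] when q is FP8.

import torch
import flashinfer

batch_size, num_qo_heads, num_kv_heads, head_dim, page_size = 2, 32, 8, 128, 16
device = "cuda"

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "NHD", use_tensor_cores=True
)

# One page per sequence, 5 valid tokens in each sequence's last (only) page.
kv_indptr = torch.tensor([0, 1, 2], dtype=torch.int32, device=device)
kv_indices = torch.tensor([0, 1], dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor([5, 5], dtype=torch.int32, device=device)

wrapper.plan(
    kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.bfloat16,  # new in this commit: FP8 in, bf16 out
)

q_fp8 = torch.randn(batch_size, num_qo_heads, head_dim, device=device).to(torch.float8_e4m3fn)
# Assumed paged layout for "NHD": [num_pages, 2, page_size, num_kv_heads, head_dim]
kv_cache = torch.randn(2, 2, page_size, num_kv_heads, head_dim, device=device).to(torch.float8_e4m3fn)

# Per-head dequantization scales, read as args[0:3] in run() when q is FP8.
scale_q = torch.ones(num_qo_heads, dtype=torch.float32, device=device)
scale_k = torch.ones(num_kv_heads, dtype=torch.float32, device=device)
scale_v = torch.ones(num_kv_heads, dtype=torch.float32, device=device)

out = wrapper.run(q_fp8, kv_cache, scale_q, scale_k, scale_v)
assert out.dtype == torch.bfloat16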

flashinfer/jit/attention/modules.py

Lines changed: 4 additions & 1 deletion

@@ -974,6 +974,9 @@ def gen_batch_prefill_module(
    # KV-only quant is not influenced by this flag
    fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]

+    assert backend in ["fa2", "fa3"], f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], "FP8 output is not supported in fa2/fa3 backends yet"
+
    if backend == "fa2":
        assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
        additional_tensor_names = [
@@ -1019,7 +1022,7 @@ def gen_batch_prefill_module(
            variant_name = f"DefaultAttention<{str(use_logits_soft_cap).lower()}>"
            variant_decl = "#include<flashinfer/attention/hopper/variants.cuh>"
        else:
-            additional_tensor_names = ["scale_q", "scale_k", "scale_v"]
+            additional_tensor_names = ["maybe_scale_q", "maybe_scale_k", "maybe_scale_v"]
            additional_tensor_dtypes = ["float", "float", "float"]
            additional_scalar_names = ["sm_scale"]
            additional_scalar_dtypes = ["double"]

flashinfer/prefill.py

Lines changed: 11 additions & 3 deletions

@@ -1650,7 +1650,7 @@ def plan(
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
        o_data_type : Optional[Union[str, torch.dtype]]
            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
        prefix_len_ptr : Optional[torch.Tensor]
@@ -1699,6 +1699,7 @@ def plan(
        The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
        """
        q_data_type = canonicalize_torch_dtype(q_data_type)
+
        if kv_data_type is None:
            kv_data_type = q_data_type
        kv_data_type = canonicalize_torch_dtype(kv_data_type)
@@ -2025,6 +2026,8 @@ def run(

        *args
            Additional arguments for custom kernels.
+        q_scale : Optional[float]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
        k_scale : Optional[float]
            The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
        v_scale : Optional[float]
@@ -2053,6 +2056,11 @@ def run(
        _check_cached_qkv_data_type(
            q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
        )
+        o_dtype = self._cached_o_data_type
+        if out is not None and out.dtype != o_dtype:
+            raise ValueError(
+                f"The dtype of out {out.dtype} does not match the o_data_type {o_dtype} specified in plan function."
+            )

        if self._kv_layout == "NHD":
            page_size = k_cache.shape[1]
@@ -2206,7 +2214,7 @@ def run(

        assert self._cached_module is not None, "cached module is not initialized"
        self._cached_module.paged_run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
            # TODO(Zihao): fused into kernel
            if is_float8(out):
                out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2592,7 +2600,7 @@ def plan(
            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
        o_data_type : Optional[Union[str, torch.dtype]]
            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
        non_blocking : bool
            Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
        prefix_len_ptr : Optional[torch.Tensor]
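One caller-visible consequence of the new dtype check in run(): a preallocated out buffer must be created in the o_data_type given to plan(), not in q's dtype. A minimal illustration of the rule, with hypothetical shapes and dtypes that are not taken from this diff:

import torch

planned_o_dtype = torch.bfloat16  # what was passed as o_data_type in plan()
q = torch.empty(8, 32, 128, dtype=torch.float8_e4m3fn, device="cuda")

# Wrong: allocating out in q's dtype now fails the dtype check in run().
# out = torch.empty_like(q)

# Right: allocate in the planned output dtype.
out = torch.empty(q.shape, dtype=planned_o_dtype, device=q.device)
assert out.dtype == planned_o_dtype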

include/flashinfer/attention/hopper/quantization/mainloop_sparse_load.cuh

Lines changed: 2 additions & 2 deletions

@@ -212,8 +212,8 @@ struct FP8SparseCollectiveMainloop {
    IdType const* kv_indices_ptr = mainloop_params.kv_indices + kv_indptr;

    // Setup for manual K/V loading with page table
-    DTypeKV* k_base_ptr = mainloop_params.K_ptr;
-    DTypeKV* v_base_ptr = mainloop_params.V_ptr;
+    DTypeKV* k_base_ptr = mainloop_params.K_ptr + kv_head_idx * HEAD_DIM;
+    DTypeKV* v_base_ptr = mainloop_params.V_ptr + kv_head_idx * HEAD_DIM;
    int64_t k_stride_n = mainloop_params.k_stride_n;
    int64_t k_page_stride = mainloop_params.k_page_stride;
    int64_t v_stride_n = mainloop_params.v_stride_n;
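The two-line change above folds the per-head offset into the K/V base pointers, so each KV head reads its own slice instead of head 0's. A small index-arithmetic sketch of why that offset is kv_head_idx * HEAD_DIM, assuming a page layout whose two innermost contiguous dimensions are [num_kv_heads, head_dim] (an assumption made for illustration, not stated in the diff):

import torch

# Hypothetical paged-K buffer: [num_pages, page_size, num_kv_heads, head_dim],
# flattened the way the kernel sees it through a raw pointer.
num_pages, page_size, num_kv_heads, head_dim = 3, 4, 2, 8
k = torch.arange(num_pages * page_size * num_kv_heads * head_dim).reshape(
    num_pages, page_size, num_kv_heads, head_dim
)
flat = k.reshape(-1)

page_idx, row_idx, kv_head_idx = 2, 1, 1
k_page_stride = page_size * num_kv_heads * head_dim  # elements between pages
k_stride_n = num_kv_heads * head_dim                 # elements between rows (tokens)

# The base pointer now carries the per-head offset, as in the fixed kernel code.
base = kv_head_idx * head_dim
offset = base + page_idx * k_page_stride + row_idx * k_stride_n

assert torch.equal(flat[offset : offset + head_dim], k[page_idx, row_idx, kv_head_idx])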

include/flashinfer/attention/hopper/variants.cuh

Lines changed: 8 additions & 4 deletions

@@ -70,10 +70,14 @@ struct StandardFP8Attention {
        block_coord;
    // 448 for e4m3; 57344 for e5m2
    p_scale = std::numeric_limits<typename MainloopParams::DTypeKV>::max();
-    scale_pv = params.additional_params.scale_v[kv_head_idx] / p_scale;
-    sm_scale_with_qk_log2 = params.additional_params.scale_q[qo_head_idx] *
-                            params.additional_params.scale_k[kv_head_idx] *
-                            params.additional_params.sm_scale * math::log2e;
+    const float* scale_q_ptr = params.additional_params.maybe_scale_q;
+    const float* scale_k_ptr = params.additional_params.maybe_scale_k;
+    const float* scale_v_ptr = params.additional_params.maybe_scale_v;
+    const float scale_q = scale_q_ptr ? scale_q_ptr[qo_head_idx] : 1.0f;
+    const float scale_k = scale_k_ptr ? scale_k_ptr[kv_head_idx] : 1.0f;
+    const float scale_v = scale_v_ptr ? scale_v_ptr[kv_head_idx] : 1.0f;
+    scale_pv = scale_v / p_scale;
+    sm_scale_with_qk_log2 = scale_q * scale_k * params.additional_params.sm_scale * math::log2e;
  }

  template <int NUM_ROWS_PER_THREAD>
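The variant now tolerates missing scale tensors (null pointers default the per-head scales to 1.0) and folds the calibration scales into two scalars: scale_pv = scale_v / p_scale, where p_scale is the maximum of the KV FP8 type (448 for e4m3, 57344 for e5m2), and sm_scale_with_qk_log2 = scale_q * scale_k * sm_scale * log2(e). A host-side Python mirror of that arithmetic, written here only for illustration and not part of the commit:

import math
import torch

def fp8_attention_scales(sm_scale, kv_dtype=torch.float8_e4m3fn,
                         scale_q=None, scale_k=None, scale_v=None):
    """Illustrative mirror of the per-head scalar computation in the diff above."""
    # 448 for e4m3; 57344 for e5m2 -- the value P was quantized against.
    p_scale = torch.finfo(kv_dtype).max
    # Missing scale arrays behave like all-ones, matching the null-pointer default.
    scale_q = 1.0 if scale_q is None else scale_q
    scale_k = 1.0 if scale_k is None else scale_k
    scale_v = 1.0 if scale_v is None else scale_v
    scale_pv = scale_v / p_scale
    sm_scale_with_qk_log2 = scale_q * scale_k * sm_scale * math.log2(math.e)
    return scale_pv, sm_scale_with_qk_log2

# Example: head_dim = 128, no calibration scales provided.
print(fp8_attention_scales(sm_scale=1.0 / math.sqrt(128.0)))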
