
Commit e7b06f6

Enable Hopper FA3 FP8 attention
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
1 parent 12b8ad0 commit e7b06f6

3 files changed: +95 -21 lines changed

flashinfer/decode.py

Lines changed: 69 additions & 19 deletions
@@ -55,6 +55,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -710,7 +711,7 @@ def __init__(
             self._jit_module = get_batch_prefill_jit_module(
                 jit_args[0],
                 gen_customize_batch_prefill_module(
-                    "fa2", *jit_args
+                    backend, *jit_args
                 ).build_and_load(),
             )
         else:
@@ -822,6 +823,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -869,6 +871,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
@@ -964,6 +969,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -973,6 +982,7 @@ def plan(
 
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1012,7 +1022,7 @@ def plan(
                 self._cached_module = get_trtllm_gen_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,
                     head_dim,
@@ -1027,11 +1037,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1041,7 +1060,12 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )
 
-            self._plan_info = self._cached_module.plan(
+            if self._backend == "fa3" and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                self._dummy_scale = torch.ones(1, device=self.device, dtype=torch.float32)
+            else:
+                self._dummy_scale = None
+
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1058,9 +1082,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1069,7 +1097,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1278,9 +1306,13 @@ def run(
         )
 
         if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")
 
         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1308,6 +1340,20 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
+            if fp8_scale_q is None:
+                fp8_scale_q = self._dummy_scale
+            if fp8_scale_k is None:
+                fp8_scale_k = self._dummy_scale
+            if fp8_scale_v is None:
+                fp8_scale_v = self._dummy_scale
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
@@ -1317,9 +1363,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -2921,8 +2967,8 @@ def fast_decode_plan(
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
 
     try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -2939,9 +2985,13 @@ def fast_decode_plan(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
        raise RuntimeError(f"Error in standard plan: {e}") from e
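
Taken together, the decode.py changes let the tensor-core decode path select the FA3 backend on Hopper and run FP8 query/KV with a 16-bit output dtype. Below is a minimal usage sketch assuming the public BatchDecodeWithPagedKVCacheWrapper API; the shapes, page counts, and scale values are made-up illustration, and passing scale tensors as the first three extra positional arguments to run() is inferred from the scale-extraction logic above, so treat this as a sketch rather than code from this commit.

    import torch
    import flashinfer

    # Hypothetical example sizes: 4 requests, 32 query heads, 8 KV heads, head_dim 128.
    batch_size, num_qo_heads, num_kv_heads, head_dim, page_size = 4, 32, 8, 128, 16
    max_pages = 64

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace, "NHD", use_tensor_cores=True  # tensor-core decode reuses the fa2/fa3 prefill modules
    )

    # Four fully filled pages per request (illustrative values only).
    kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * 4
    kv_indices = torch.arange(batch_size * 4, dtype=torch.int32, device="cuda")
    kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        q_data_type=torch.float8_e4m3fn,
        kv_data_type=torch.float8_e4m3fn,
        o_data_type=torch.bfloat16,  # new in this commit: FP8 inputs, 16-bit output
    )

    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda",
                    dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    kv_cache = torch.randn(max_pages, 2, page_size, num_kv_heads, head_dim, device="cuda",
                           dtype=torch.bfloat16).to(torch.float8_e4m3fn)

    # Per the run() changes above, per-tensor scales may be passed as extra positional
    # args; if omitted, the dummy ones(1) scale created in plan() is substituted.
    scale_q = torch.ones(1, device="cuda", dtype=torch.float32)
    scale_k = torch.ones(1, device="cuda", dtype=torch.float32)
    scale_v = torch.ones(1, device="cuda", dtype=torch.float32)
    out = wrapper.run(q, kv_cache, scale_q, scale_k, scale_v)  # out.dtype == torch.bfloat16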

flashinfer/jit/attention/modules.py

Lines changed: 3 additions & 0 deletions
@@ -974,6 +974,9 @@ def gen_batch_prefill_module(
     # KV-only quant is not influenced by this flag
     fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]
 
+    assert backend in ["fa2", "fa3"], f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], "FP8 output is not supported in fa2/fa3 backends yet"
+
     if backend == "fa2":
         assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
         additional_tensor_names = [
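
The new assertions make the supported dtype combinations explicit: only the fa2 and fa3 backends reach this code path, FP8 output is rejected, and (per the existing check) FP8 query input is rejected on fa2. A small standalone sketch of that guard logic, written purely for illustration and not taken from the module:

    import torch

    FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

    def check_prefill_dtypes(backend: str, dtype_q: torch.dtype, dtype_o: torch.dtype) -> None:
        # Mirrors the guards added above: fa2/fa3 only, no FP8 output, no FP8 Q on fa2.
        assert backend in ("fa2", "fa3"), f"unsupported backend: {backend}"
        assert dtype_o not in FP8_DTYPES, "FP8 output is not supported in fa2/fa3 backends yet"
        if backend == "fa2":
            assert dtype_q not in FP8_DTYPES, "fp8 tensor core is not supported in fa2 backend"

    check_prefill_dtypes("fa3", torch.float8_e4m3fn, torch.bfloat16)  # OK: Hopper FA3 FP8 path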

flashinfer/prefill.py

Lines changed: 23 additions & 2 deletions
@@ -1650,7 +1650,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
@@ -1699,6 +1699,7 @@ def plan(
         The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
         """
         q_data_type = canonicalize_torch_dtype(q_data_type)
+
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
@@ -1867,6 +1868,13 @@ def plan(
                 self._backend, *get_module_args
             )
 
+        # FA3 FP8 kernel requires scale_q/scale_k/scale_v to be mandatory device tensors.
+        # Reserve a dummy scale tensor in case users do not provide scale tensors in run().
+        if self._backend == "fa3" and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            self._dummy_scale = torch.ones(1, device=self.device, dtype=torch.float32)
+        else:
+            self._dummy_scale = None
+
         self._block_tables = block_tables
         if self._backend == "trtllm-gen":
             assert logits_soft_cap == 0.0
@@ -2025,6 +2033,8 @@ def run(
 
         *args
             Additional arguments for custom kernels.
+        q_scale : Optional[float]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
         k_scale : Optional[float]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
         v_scale : Optional[float]
@@ -2053,6 +2063,11 @@ def run(
         _check_cached_qkv_data_type(
             q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
         )
+        o_dtype = self._cached_o_data_type
+        if out is not None and out.dtype != o_dtype:
+            raise ValueError(
+                f"The dtype of out {out.dtype} does not match the o_data_type {o_dtype} specified in plan function."
+            )
 
         if self._kv_layout == "NHD":
             page_size = k_cache.shape[1]
@@ -2175,6 +2190,12 @@ def run(
                 fp8_scale_q = args[0]
                 fp8_scale_k = args[1]
                 fp8_scale_v = args[2]
+            if fp8_scale_q is None:
+                fp8_scale_q = self._dummy_scale
+            if fp8_scale_k is None:
+                fp8_scale_k = self._dummy_scale
+            if fp8_scale_v is None:
+                fp8_scale_v = self._dummy_scale
             run_args += [
                 self._custom_mask_buf,
                 self._mask_indptr_buf,
@@ -2592,7 +2613,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
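
Both decode.py and prefill.py now share the same pattern: plan() pre-allocates a ones(1) FP32 device tensor when the backend is fa3 and the query dtype is FP8, and run() substitutes it for any scale_q/scale_k/scale_v the caller did not pass, since the FA3 FP8 kernel expects real device tensors for all three scales. The following standalone sketch (illustrative only, not library code; the helper name resolve_fp8_scales is made up) mirrors that fallback:

    from typing import Optional, Sequence, Tuple
    import torch

    def resolve_fp8_scales(
        extra_args: Sequence[Optional[torch.Tensor]],
        dummy_scale: Optional[torch.Tensor],
        q_is_fp8: bool,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Pick scale_q/scale_k/scale_v from the caller's extra args, else fall back to the dummy scale."""
        scale_q = scale_k = scale_v = None
        if q_is_fp8 and len(extra_args) >= 3:
            scale_q, scale_k, scale_v = extra_args[0], extra_args[1], extra_args[2]
        # For non-FP8 runs dummy_scale is None, so the scales stay None as before.
        scale_q = scale_q if scale_q is not None else dummy_scale
        scale_k = scale_k if scale_k is not None else dummy_scale
        scale_v = scale_v if scale_v is not None else dummy_scale
        return scale_q, scale_k, scale_v

    # plan()-time setup for an FP8 FA3 run (CPU tensors here to keep the sketch self-contained):
    dummy = torch.ones(1, dtype=torch.float32)
    print(resolve_fp8_scales([], dummy, q_is_fp8=True))  # -> (dummy, dummy, dummy)
    print(resolve_fp8_scales([None, None, torch.tensor([0.5])], dummy, q_is_fp8=True))  # explicit scale_v wins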
