
Commit ede67a3

Enable Hopper FA3 FP8 attention
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
1 parent: 12b8ad0

File tree: 3 files changed (+97 −21 lines)

flashinfer/decode.py

Lines changed: 70 additions & 19 deletions
@@ -55,6 +55,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -710,7 +711,7 @@ def __init__(
             self._jit_module = get_batch_prefill_jit_module(
                 jit_args[0],
                 gen_customize_batch_prefill_module(
-                    "fa2", *jit_args
+                    backend, *jit_args
                 ).build_and_load(),
             )
         else:
@@ -822,6 +823,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -869,6 +871,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
@@ -964,6 +969,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -973,6 +982,7 @@ def plan(

         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1012,7 +1022,7 @@ def plan(
             self._cached_module = get_trtllm_gen_decode_module(
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 indptr.dtype,
                 head_dim,
                 head_dim,
@@ -1027,11 +1037,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1041,7 +1060,13 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )

-            self._plan_info = self._cached_module.plan(
+            if self._backend == "fa3" and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+                num_heads = max(num_qo_heads, num_kv_heads)
+                self._dummy_scales = torch.ones(num_heads, device=self.device, dtype=torch.float32)
+            else:
+                self._dummy_scales = None
+
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1058,9 +1083,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1069,7 +1098,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1278,9 +1307,13 @@ def run(
         )

         if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1308,6 +1341,20 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
+            if fp8_scale_q is None:
+                fp8_scale_q = self._dummy_scales
+            if fp8_scale_k is None:
+                fp8_scale_k = self._dummy_scales
+            if fp8_scale_v is None:
+                fp8_scale_v = self._dummy_scales
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
@@ -1317,9 +1364,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -2921,8 +2968,8 @@ def fast_decode_plan(
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

     try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -2939,9 +2986,13 @@ def fast_decode_plan(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in standard plan: {e}") from e
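
The decode-side changes above add an o_data_type knob to plan() and route FP8 queries through the FA3 prefill module with per-head scale tensors. Below is a minimal usage sketch (not part of this commit) for FP8 decode on a Hopper GPU. It assumes the public BatchDecodeWithPagedKVCacheWrapper API with use_tensor_cores=True; the per-head float32 shapes of scale_q/scale_k/scale_v are an assumption inferred from the dummy scales reserved in plan(), and exact signatures may differ across flashinfer versions.

import torch
import flashinfer

# Requires an SM90 (Hopper) GPU; FP8 decode is dispatched to the FA3 backend.
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size = 4, 32, 8, 128, 16
pages_per_req, max_pages = 4, 64

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, kv_layout="NHD", use_tensor_cores=True
)

# One contiguous block of pages per request.
kv_page_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_req
kv_page_indices = torch.arange(batch_size * pages_per_req, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.bfloat16,  # FP8 output is rejected by the new assert; use fp16/bf16
)

q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
# Single-tensor paged KV cache, NHD layout: (max_pages, 2, page_size, num_kv_heads, head_dim)
kv_cache = torch.randn(
    max_pages, 2, page_size, num_kv_heads, head_dim, device="cuda"
).to(torch.float8_e4m3fn)

# Optional per-head dequantization scales, consumed as args[0:3] in run() per the diff.
# If omitted, the all-ones dummy scales reserved in plan() are used instead.
scale_q = torch.ones(num_qo_heads, dtype=torch.float32, device="cuda")
scale_k = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")
scale_v = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")

out = wrapper.run(q, kv_cache, scale_q, scale_k, scale_v)  # bf16 output, per o_data_type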

flashinfer/jit/attention/modules.py

Lines changed: 3 additions & 0 deletions
@@ -974,6 +974,9 @@ def gen_batch_prefill_module(
     # KV-only quant is not influenced by this flag
     fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]

+    assert backend in ["fa2", "fa3"], f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], "FP8 output is not supported in fa2/fa3 backends yet"
+
     if backend == "fa2":
         assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
         additional_tensor_names = [
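
For reference, the two new asserts plus the pre-existing fa2 check amount to a small compatibility gate over (backend, dtype_q, dtype_o). The snippet below is an illustration only, not flashinfer code; the helper name is made up and simply restates the rules from the diff.

import torch

FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)

def prefill_module_config_ok(backend: str, dtype_q: torch.dtype, dtype_o: torch.dtype) -> bool:
    """Hypothetical restatement of the asserts in gen_batch_prefill_module()."""
    if backend not in ("fa2", "fa3"):
        return False  # only fa2/fa3 are accepted by this code path
    if dtype_o in FP8_DTYPES:
        return False  # FP8 output is not supported in fa2/fa3 backends yet
    if backend == "fa2" and dtype_q in FP8_DTYPES:
        return False  # fp8 tensor-core prefill is fa3-only
    return True

assert prefill_module_config_ok("fa3", torch.float8_e4m3fn, torch.bfloat16)
assert not prefill_module_config_ok("fa2", torch.float8_e4m3fn, torch.bfloat16)
assert not prefill_module_config_ok("fa3", torch.float8_e4m3fn, torch.float8_e4m3fn)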

flashinfer/prefill.py

Lines changed: 24 additions & 2 deletions
@@ -1650,7 +1650,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
@@ -1699,6 +1699,7 @@ def plan(
         The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``.
         """
         q_data_type = canonicalize_torch_dtype(q_data_type)
+
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
@@ -1867,6 +1868,14 @@ def plan(
                 self._backend, *get_module_args
             )

+        # FA3 FP8 kernel requires scale_q/scale_k/scale_v to be mandatory device tensors.
+        # Reserve a dummy scale tensor in case users do not provide scale tensors in run().
+        if self._backend == "fa3" and q_data_type in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            num_heads = max(num_qo_heads, num_kv_heads)
+            self._dummy_scales = torch.ones(num_heads, device=self.device, dtype=torch.float32)
+        else:
+            self._dummy_scales = None
+
         self._block_tables = block_tables
         if self._backend == "trtllm-gen":
             assert logits_soft_cap == 0.0
@@ -2025,6 +2034,8 @@ def run(

         *args
             Additional arguments for custom kernels.
+        q_scale : Optional[float]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
         k_scale : Optional[float]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
         v_scale : Optional[float]
@@ -2053,6 +2064,11 @@ def run(
         _check_cached_qkv_data_type(
             q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
         )
+        o_dtype = self._cached_o_data_type
+        if out is not None and out.dtype != o_dtype:
+            raise ValueError(
+                f"The dtype of out {out.dtype} does not match the o_data_type {o_dtype} specified in plan function."
+            )

         if self._kv_layout == "NHD":
             page_size = k_cache.shape[1]
@@ -2175,6 +2191,12 @@ def run(
                 fp8_scale_q = args[0]
                 fp8_scale_k = args[1]
                 fp8_scale_v = args[2]
+            if fp8_scale_q is None:
+                fp8_scale_q = self._dummy_scales
+            if fp8_scale_k is None:
+                fp8_scale_k = self._dummy_scales
+            if fp8_scale_v is None:
+                fp8_scale_v = self._dummy_scales
             run_args += [
                 self._custom_mask_buf,
                 self._mask_indptr_buf,
@@ -2592,7 +2614,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
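
The prefill-side changes mirror the decode path: run() now validates a user-provided out against the o_data_type cached in plan(), and falls back to the dummy all-ones scales when no FP8 scale tensors are passed. A minimal sketch follows (not from the commit), assuming the public BatchPrefillWithPagedKVCacheWrapper API; the positional plan() arguments and the per-head float32 scale shapes are assumptions that may differ between flashinfer versions.

import torch
import flashinfer

batch_size, num_qo_heads, num_kv_heads, head_dim, page_size = 2, 32, 8, 128, 16
qo_len, pages_per_req = 64, 8

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, kv_layout="NHD", backend="fa3")

qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * qo_len
kv_page_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_req
kv_page_indices = torch.arange(batch_size * pages_per_req, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    qo_indptr,
    kv_page_indptr,
    kv_page_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.bfloat16,
)

q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
kv_cache = torch.randn(
    batch_size * pages_per_req, 2, page_size, num_kv_heads, head_dim, device="cuda"
).to(torch.float8_e4m3fn)

# Per-head dequantization scales, consumed as args[0:3] in run(); if omitted,
# the dummy all-ones scales reserved in plan() are used instead.
scale_q = torch.ones(num_qo_heads, dtype=torch.float32, device="cuda")
scale_k = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")
scale_v = torch.ones(num_kv_heads, dtype=torch.float32, device="cuda")

# A pre-allocated `out` must match o_data_type, otherwise run() raises the new ValueError.
out = torch.empty(batch_size * qo_len, num_qo_heads, head_dim, dtype=torch.bfloat16, device="cuda")
out = wrapper.run(q, kv_cache, scale_q, scale_k, scale_v, out=out)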
