@@ -413,7 +413,13 @@ def ragged_run(
     rope_scale: float,
     rope_theta: float,
     token_pos_in_items_len: int,
+    scale_q: Optional[torch.Tensor] = None,
+    scale_k: Optional[torch.Tensor] = None,
+    scale_v: Optional[torch.Tensor] = None,
 ) -> None:
+    # Check if FP8 by presence of scale tensors
+    is_fp8 = scale_q is not None
+
     if backend == "fa2":
         ragged_run_func(
             float_workspace_buffer,
@@ -439,10 +445,34 @@ def ragged_run(
             logits_soft_cap,
             sm_scale,
             1.0 / rope_scale,  # rope_rcp_scale
-            1.0 / rope_theta,  # rope_rcp_theta
+            1.0 / rope_theta,  # rope_rcp_theta,
             token_pos_in_items_len,
         )
+    elif is_fp8:
+        # FA3 FP8: scale_q, scale_k, scale_v, sm_scale
+        ragged_run_func(
+            float_workspace_buffer,
+            int_workspace_buffer,
+            plan_info_vec,
+            q,
+            k,
+            v,
+            qo_indptr,
+            kv_indptr,
+            o,
+            maybe_lse,
+            mask_mode,
+            layout,
+            window_left,
+            enable_pdl,
+            scale_q,
+            scale_k,
+            scale_v,
+            sm_scale,
+        )
     else:
+        # FA3 FP16: maybe_prefix_len_ptr, maybe_token_pos_in_items_ptr,
+        # maybe_max_item_len_ptr, logits_soft_cap, sm_scale, token_pos_in_items_len
         ragged_run_func(
             float_workspace_buffer,
             int_workspace_buffer,
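
For orientation, the new `is_fp8` branch only fires when the caller supplies scale factors alongside FP8 `q`/`k`/`v`. Below is a minimal sketch of how such scales could be produced on the caller side; `quantize_fp8_per_tensor` is a hypothetical helper, not part of this diff, and the exact scale granularity the kernel expects (per-tensor vs. per-head) is not visible here.

```python
import torch

def quantize_fp8_per_tensor(x: torch.Tensor):
    """Hypothetical helper: symmetric per-tensor quantization to float8_e4m3fn."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().amax().clamp(min=1e-6) / fp8_max
    x_fp8 = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Return the quantized tensor and the scale consumed by the FP8 branch above.
    return x_fp8, scale.to(torch.float32).reshape(1)

# Example (names illustrative):
# q_fp8, scale_q = quantize_fp8_per_tensor(q_fp16)
# k_fp8, scale_k = quantize_fp8_per_tensor(k_fp16)
# v_fp8, scale_v = quantize_fp8_per_tensor(v_fp16)
```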
@@ -1533,6 +1563,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         non_blocking: bool = True,
         prefix_len_ptr: Optional[torch.Tensor] = None,
         token_pos_in_items_ptr: Optional[torch.Tensor] = None,
@@ -1617,6 +1648,9 @@ def plan(
             The data type of the query tensor, defaults to torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
@@ -1668,6 +1702,9 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
 
         if logits_soft_cap is None:
             logits_soft_cap = 0.0
@@ -1798,6 +1835,7 @@ def plan(
 
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
 
         if self._jit_module is not None:
             self._cached_module = self._jit_module
@@ -1815,7 +1853,7 @@ def plan(
             get_module_args = (
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 paged_kv_indptr.dtype,
                 head_dim_qk,
                 head_dim_vo,
@@ -2052,12 +2090,15 @@ def run(
             )
 
         if out is None:
+            # Use cached output data type if available (for FP8 attention with FP16 output)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
             out = torch.empty(
-                q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
             )
         else:
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
             check_shape_dtype_device(
-                out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
+                out, q.shape[:-1] + v_cache.shape[-1:], out_dtype, q.device, "out"
             )
 
         # Convert NHD layout to HND for trtllm-gen backend
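
A small illustration of the `getattr(..., None) or q.dtype` fallback used above, assuming nothing beyond standard PyTorch: wrappers planned with an explicit `o_data_type` get that dtype for `out`, while wrappers whose plan() never cached the attribute keep the old `out.dtype == q.dtype` behavior.

```python
import torch

class _PlannedWrapper:
    """Stand-in object for the demo; only the cached attribute matters."""

w = _PlannedWrapper()
q_dtype = torch.float8_e4m3fn

# No cached attribute (e.g. plan() predates this change): fall back to q.dtype.
assert (getattr(w, "_cached_o_data_type", None) or q_dtype) == torch.float8_e4m3fn

# plan(..., o_data_type=torch.float16) cached the output dtype: it wins.
w._cached_o_data_type = torch.float16
assert (getattr(w, "_cached_o_data_type", None) or q_dtype) == torch.float16
```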
@@ -2126,6 +2167,14 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
             run_args += [
                 self._custom_mask_buf,
                 self._mask_indptr_buf,
@@ -2135,9 +2184,9 @@ def run(
                 self._max_item_len_ptr,
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 self._token_pos_in_items_len,
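
With the paged wrapper, the scales are not new keyword arguments; they travel through `*args` and are unpacked positionally as shown above. A sketch of a call site under that assumption follows; the wrapper object and tensor names are illustrative, not a confirmed public signature.

```python
import torch

def run_fp8_paged_prefill(wrapper, q_fp8, paged_kv_cache_fp8,
                          scale_q, scale_k, scale_v):
    # The three scale tensors ride in as extra positional arguments; inside
    # run() they arrive in *args and are forwarded as fp8_scale_q/k/v above.
    return wrapper.run(q_fp8, paged_kv_cache_fp8, scale_q, scale_k, scale_v)
```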
@@ -2466,6 +2515,7 @@ def plan(
         rope_theta: Optional[float] = None,
         q_data_type: Union[str, torch.dtype] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         non_blocking: bool = True,
         prefix_len_ptr: Optional[torch.Tensor] = None,
         token_pos_in_items_ptr: Optional[torch.Tensor] = None,
@@ -2540,6 +2590,9 @@ def plan(
             The data type of the query tensor, defaults to torch.float16.
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
@@ -2580,6 +2633,9 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
         if head_dim_vo is None:
             head_dim_vo = head_dim_qk
         if fixed_split_size is None:
@@ -2652,6 +2708,7 @@ def plan(
 
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1]
 
         self._prefix_len_ptr = prefix_len_ptr
@@ -2675,7 +2732,7 @@ def plan(
             get_module_args = (
                 q_data_type,
                 kv_data_type,
-                q_data_type,
+                o_data_type,
                 kv_indptr.dtype,
                 head_dim_qk,
                 head_dim_vo,
@@ -2862,11 +2919,17 @@ def run(
             )
         if out is None:
             out = torch.empty(
-                q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
+                q.shape[:-1] + v.shape[-1:],
+                dtype=self._cached_o_data_type,
+                device=q.device,
             )
         else:
             check_shape_dtype_device(
-                out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out"
+                out,
+                q.shape[:-1] + v.shape[-1:],
+                self._cached_o_data_type,
+                q.device,
+                "out",
             )
         if self._backend == "cutlass":
             out, lse = fmha_varlen(
@@ -2884,7 +2947,9 @@ def run(
             )
             return (out, lse) if return_lse else out
 
-        if is_float8(q):
+        # Skip FP8->FP16 conversion for FA3 backend with FP8 support
+        # The JIT module will handle FP8 natively
+        if is_float8(q) and self._backend != "fa3":
             logging.warning(
                 "Our current prefill kernel implementation needs f16 input, the f8 inputs "
                 " are casted to f16, which could result in performance degradation."
@@ -2933,6 +2998,9 @@ def run(
                 rope_theta,
                 self._token_pos_in_items_len,
             ]
+            # For FP8, append scale tensors
+            if is_float8(q):
+                run_args.extend(list(args))  # scale_q, scale_k, scale_v
 
         assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.ragged_run(*run_args)
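
Putting the pieces together for the ragged path: plan() now accepts `o_data_type`, the module key uses it in place of the second `q_data_type`, and run() forwards the FP8 scales through `*args` when `q` is FP8. A hedged end-to-end sketch, assuming a flashinfer-style ragged prefill wrapper created with `backend="fa3"`; class and argument names outside this diff are assumptions, not a confirmed public API.

```python
import torch

def plan_and_run_fp8_ragged(wrapper, qo_indptr, kv_indptr, num_qo_heads,
                            num_kv_heads, head_dim, q_fp8, k_fp8, v_fp8,
                            scale_q, scale_k, scale_v):
    # Sketch only: `wrapper` is assumed to be a ragged-KV prefill wrapper
    # built with backend="fa3"; only the dtype arguments are taken from
    # this diff, the rest is illustrative.
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        q_data_type=torch.float8_e4m3fn,
        kv_data_type=torch.float8_e4m3fn,
        o_data_type=torch.float16,  # FP8 in, FP16 out (see docstring above)
        causal=True,
    )
    # scale_q/k/v are forwarded through *args and appended to run_args
    # by the `if is_float8(q)` branch at the end of the diff.
    return wrapper.run(q_fp8, k_fp8, v_fp8, scale_q, scale_k, scale_v)
```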