@@ -55,6 +55,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -710,7 +711,7 @@ def __init__(
                 self._jit_module = get_batch_prefill_jit_module(
                     jit_args[0],
                     gen_customize_batch_prefill_module(
-                        "fa2", *jit_args
+                        backend, *jit_args
                     ).build_and_load(),
                 )
             else:
@@ -822,6 +823,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -869,6 +871,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to ``torch.float16`` or ``torch.bfloat16``.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
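
A usage sketch (not part of the diff): the new `o_data_type` knob decouples output precision from query precision, which matters for FP8 decode. Assuming this class is flashinfer's `BatchDecodeWithPagedKVCacheWrapper` and hypothetical page-table tensors `kv_indptr`, `kv_indices`, `kv_last_page_len`:

    # FP8 query/KV cache with a bfloat16 output; o_data_type defaults to
    # q_data_type when omitted.
    wrapper.plan(
        kv_indptr, kv_indices, kv_last_page_len,
        num_qo_heads, num_kv_heads, head_dim, page_size,
        q_data_type=torch.float8_e4m3fn,
        kv_data_type=torch.float8_e4m3fn,
        o_data_type=torch.bfloat16,
    )
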
@@ -964,6 +969,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -973,6 +982,7 @@ def plan(

         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1012,7 +1022,7 @@ def plan(
                 self._cached_module = get_trtllm_gen_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,
                     head_dim,
@@ -1027,11 +1037,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1041,7 +1060,7 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )

-            self._plan_info = self._cached_module.plan(
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1058,9 +1077,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1069,7 +1092,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1278,9 +1301,13 @@ def run(
         )

         if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1308,6 +1335,14 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
@@ -1317,9 +1352,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -1372,7 +1407,7 @@ def run(
             ]

         self._cached_module.run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2921,8 +2956,8 @@ def fast_decode_plan(
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

     try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Pass exactly 19 arguments for the fa2 backend and 16 for the fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -2939,9 +2974,13 @@ def fast_decode_plan(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in standard plan: {e}") from e
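
Taken together, the plan/run changes above combine as in the following minimal end-to-end sketch (not part of the diff). It assumes the surrounding class is flashinfer's `BatchDecodeWithPagedKVCacheWrapper`, a CUDA device with FP8 support, and unit scale tensors as placeholders for real calibration values:

    import torch
    import flashinfer

    batch_size, pages_per_seq, page_size = 4, 16, 16
    num_qo_heads, num_kv_heads, head_dim = 32, 8, 128

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace, "NHD", use_tensor_cores=True
    )

    # Page table: each request owns pages_per_seq full pages.
    kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
    kv_indices = torch.arange(batch_size * pages_per_seq, dtype=torch.int32, device="cuda")
    kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

    wrapper.plan(
        kv_indptr, kv_indices, kv_last_page_len,
        num_qo_heads, num_kv_heads, head_dim, page_size,
        q_data_type=torch.float8_e4m3fn,
        kv_data_type=torch.float8_e4m3fn,
        o_data_type=torch.bfloat16,  # new in this diff
    )

    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda").to(torch.float8_e4m3fn)
    kv_cache = torch.randn(
        batch_size * pages_per_seq, 2, page_size, num_kv_heads, head_dim, device="cuda"
    ).to(torch.float8_e4m3fn)

    # The three positional extras are picked up as scale_q/scale_k/scale_v by the
    # *args handling added to run() above (per-tensor scales; ones are placeholders).
    ones = torch.ones(1, dtype=torch.float32, device="cuda")
    out = wrapper.run(q, kv_cache, ones, ones, ones)
    assert out.dtype == torch.bfloat16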