@@ -1271,6 +1271,7 @@ def build_plans_cudnn_fp4_gemm_graph(
     device,
     alpha,
     use_nvfp4,
+    tactic: int = -1,
 ):
     graph = create_cudnn_execution_plans_fp4_gemm(
         a_shape,
@@ -1290,7 +1291,10 @@ def build_plans_cudnn_fp4_gemm_graph(
     )
 
     graph.check_support()
-    graph.build_plans()
+    if tactic != -1:
+        graph.build_plan_at_index(tactic)
+    else:
+        graph.build_plans()
     return graph
 
 
@@ -1303,6 +1307,7 @@ def execute_cudnn_gemm_fp4_graph(
     alpha,
     c_final,
     workspace_buffer,
+    tactic: int = -1,
 ):
     variant_pack = {
         UIDs.A_UID.value: a.view(get_native_fp4_dtype()),
@@ -1322,7 +1327,12 @@ def execute_cudnn_gemm_fp4_graph(
 
     stream = torch.cuda.current_stream(a.device)
 
-    graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    if tactic == -1:
+        graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    else:
+        graph.execute_plan_at_index(
+            variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream)
+        )
 
 
 @functools.cache
@@ -1667,6 +1677,7 @@ def _cudnn_gemm_fp4(
     block_size: int = 16,
     use_nvfp4: bool = True,
     workspace_buffer: torch.Tensor = None,
+    tactic: int = -1,
 ):
     _check_cudnn_availability()
     # the fp4 cudnn graph will be shared for both mm and bmm, so
@@ -1698,11 +1709,12 @@ def _cudnn_gemm_fp4(
         a.device,
         alpha is not None,
         use_nvfp4,
+        tactic=tactic,
     )
 
     # execute the fp4 cudnn graph
     execute_cudnn_gemm_fp4_graph(
-        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
+        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic
     )
 
 
@@ -1714,7 +1726,48 @@ def get_valid_tactics(
             profile: OptimizationProfile,
         ) -> List[int]:
-            # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic
-            return [0]
+            # enumerate all cudnn execution plans so the autotuner can profile each one
+            _check_cudnn_availability()
+            (
+                a,
+                b,
+                a_descale,
+                b_descale,
+                alpha,
+                out_dtype,
+                out,
+                block_size,
+                use_nvfp4,
+                workspace_buffer,
+            ) = inputs
+
+            real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+            real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+            batch = real_a_shape[0]
+            expanded_a_descale_shape, expanded_a_descale_stride = (
+                _expand_block_scale_tensor_shape(a_descale, batch)
+            )
+            expanded_b_descale_shape, expanded_b_descale_stride = (
+                _expand_block_scale_tensor_shape(b_descale, batch)
+            )
+
+            graph = build_plans_cudnn_fp4_gemm_graph(
+                real_a_shape,
+                real_a_stride,
+                real_b_shape,
+                real_b_stride,
+                expanded_a_descale_shape,
+                expanded_a_descale_stride,
+                expanded_b_descale_shape,
+                expanded_b_descale_stride,
+                cudnn.data_type.FP4_E2M1,
+                _torch_data_type_to_cudnn_data_type(out_dtype),
+                block_size,
+                a.device,
+                alpha is not None,
+                use_nvfp4,
+            )
+            num_plans = graph.get_execution_plan_count()
+            return list(range(num_plans))
 
         def forward(
             self,
@@ -1746,6 +1799,7 @@ def forward(
                 block_size,
                 use_nvfp4,
                 workspace_buffer,
+                tactic=tactic,
             )
 
     return CudnnFp4GemmRunner()
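Taken together, these hunks expose each cudnn execution plan as a separate autotuner tactic: get_valid_tactics builds the graph once and returns one index per plan, forward threads the chosen index down through _cudnn_gemm_fp4 to build_plan_at_index / execute_plan_at_index, and tactic=-1 keeps the old heuristic-ordered build_plans / execute path. A minimal sketch of how a caller could sweep those tactics follows; the runner/inputs/profile objects, the helper name sweep_tactics, and the exact forward signature outside these hunks are assumptions for illustration, not part of the patch:

import torch

def sweep_tactics(runner, inputs, profile, iters=10):
    # Hypothetical helper: `runner` is assumed to be the CudnnFp4GemmRunner
    # instance returned above, and `inputs` the same list that
    # get_valid_tactics unpacks (a, b, a_descale, b_descale, alpha,
    # out_dtype, out, block_size, use_nvfp4, workspace_buffer).
    best_tactic, best_ms = -1, float("inf")
    for tactic in runner.get_valid_tactics(inputs, profile):
        runner.forward(inputs, tactic=tactic)  # warmup; builds the plan at this index
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            runner.forward(inputs, tactic=tactic)
        end.record()
        torch.cuda.synchronize()
        ms = start.elapsed_time(end) / iters
        if ms < best_ms:
            best_tactic, best_ms = tactic, ms
    # an empty tactic list leaves -1, i.e. cudnn's own heuristic ordering
    return best_tactic

Defaulting tactic to -1 in every signature is what keeps existing callers working unchanged: they never pass a tactic, so the graph is still built with build_plans() and executed with execute() exactly as before this change.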