
Commit 877eaf0

Add 5th draft of mm_fp4 backend -- enable cudnn autotune
1 parent 793c634 commit 877eaf0

File tree

2 files changed: +58 -6 lines


flashinfer/gemm.py

Lines changed: 58 additions & 4 deletions
@@ -1271,6 +1271,7 @@ def build_plans_cudnn_fp4_gemm_graph(
     device,
     alpha,
     use_nvfp4,
+    tactic: int = -1,
 ):
     graph = create_cudnn_execution_plans_fp4_gemm(
         a_shape,
@@ -1290,7 +1291,10 @@ def build_plans_cudnn_fp4_gemm_graph(
     )
 
     graph.check_support()
-    graph.build_plans()
+    if tactic != -1:
+        graph.build_plan_at_index(tactic)
+    else:
+        graph.build_plans()
     return graph
 
 
@@ -1303,6 +1307,7 @@ def execute_cudnn_gemm_fp4_graph(
     alpha,
     c_final,
     workspace_buffer,
+    tactic: int = -1,
 ):
     variant_pack = {
         UIDs.A_UID.value: a.view(get_native_fp4_dtype()),
@@ -1322,7 +1327,12 @@ def execute_cudnn_gemm_fp4_graph(
 
     stream = torch.cuda.current_stream(a.device)
 
-    graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    if tactic == -1:
+        graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream))
+    else:
+        graph.execute_plan_at_index(
+            variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream)
+        )
 
 
 @functools.cache
@@ -1667,6 +1677,7 @@ def _cudnn_gemm_fp4(
     block_size: int = 16,
     use_nvfp4: bool = True,
     workspace_buffer: torch.Tensor = None,
+    tactic: int = -1,
 ):
     _check_cudnn_availability()
     # the fp4 cudnn graph will be shared for both mm and bmm, so
@@ -1698,11 +1709,12 @@ def _cudnn_gemm_fp4(
         a.device,
         alpha is not None,
         use_nvfp4,
+        tactic=tactic,
     )
 
     # execute the fp4 cudnn graph
     execute_cudnn_gemm_fp4_graph(
-        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
+        graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer, tactic=tactic
     )
 
 
@@ -1714,7 +1726,48 @@ def get_valid_tactics(
             profile: OptimizationProfile,
         ) -> List[int]:
             # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic
-            return [0]
+            _check_cudnn_availability()
+            (
+                a,
+                b,
+                a_descale,
+                b_descale,
+                alpha,
+                out_dtype,
+                out,
+                block_size,
+                use_nvfp4,
+                workspace_buffer,
+            ) = inputs
+
+            real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+            real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+            batch = real_a_shape[0]
+            expanded_a_descale_shape, expanded_a_descale_stride = (
+                _expand_block_scale_tensor_shape(a_descale, batch)
+            )
+            expanded_b_descale_shape, expanded_b_descale_stride = (
+                _expand_block_scale_tensor_shape(b_descale, batch)
+            )
+
+            graph = build_plans_cudnn_fp4_gemm_graph(
+                real_a_shape,
+                real_a_stride,
+                real_b_shape,
+                real_b_stride,
+                expanded_a_descale_shape,
+                expanded_a_descale_stride,
+                expanded_b_descale_shape,
+                expanded_b_descale_stride,
+                cudnn.data_type.FP4_E2M1,
+                _torch_data_type_to_cudnn_data_type(out_dtype),
+                block_size,
+                a.device,
+                alpha is not None,
+                use_nvfp4,
+            )
+            num_plans = graph.get_execution_plan_count()
+            return list(range(num_plans))
 
         def forward(
             self,
@@ -1746,6 +1799,7 @@ def forward(
                 block_size,
                 use_nvfp4,
                 workspace_buffer,
+                tactic=tactic,
             )
 
     return CudnnFp4GemmRunner()
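
The tactic plumbing above is what lets FlashInfer's autotuner sweep every cuDNN execution plan for the FP4 GEMM instead of trusting only the heuristic's first pick. The sketch below is a minimal illustration of that sweep, assembled only from the helpers visible in this diff; the timing loop itself and the graph_args/exec_args bundles are hypothetical stand-ins, not code from the repository (the real flow goes through the generic AutoTuner via get_valid_tactics() and forward(tactic=...)).

import torch

def sweep_cudnn_fp4_plans(graph_args, exec_args, workspace_buffer):
    # Hypothetical tuning loop. graph_args are the shape/stride/dtype arguments
    # expected by build_plans_cudnn_fp4_gemm_graph; exec_args are
    # (a, b, a_descale, b_descale, alpha, out) in the order that
    # execute_cudnn_gemm_fp4_graph expects after the graph argument.
    probe = build_plans_cudnn_fp4_gemm_graph(*graph_args)  # tactic=-1 builds all plans
    num_plans = probe.get_execution_plan_count()

    best_tactic, best_ms = -1, float("inf")
    for tactic in range(num_plans):
        # Build only the candidate plan, then time one execution of it.
        graph = build_plans_cudnn_fp4_gemm_graph(*graph_args, tactic=tactic)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        execute_cudnn_gemm_fp4_graph(graph, *exec_args, workspace_buffer, tactic=tactic)
        end.record()
        end.synchronize()
        ms = start.elapsed_time(end)
        if ms < best_ms:
            best_tactic, best_ms = tactic, ms
    return best_tactic

In the commit itself, get_valid_tactics() only counts the available plans and returns their indices; the per-tactic timing is left to the autotuner, which then replays the winning index through forward(tactic=...).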

tests/gemm/test_mm_fp4.py

Lines changed: 0 additions & 2 deletions
@@ -40,8 +40,6 @@ def test_mm_fp4(
         pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
     if not use_128x4_sf_layout and backend != "trtllm":
         pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
-    if auto_tuning and backend == "cudnn":
-        pytest.skip("Skipping test for cudnn fp4 with auto_tuning=True")
     if not use_nvfp4 and backend != "cudnn":
         pytest.skip("mx_fp4 is only supported for cudnn backend")
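
With the skip removed, backend="cudnn" now runs under the same auto_tuning=True parametrization as the other backends. The snippet below is a hypothetical illustration of what that combination exercises, assuming mm_fp4 and the autotune context manager are importable as shown; the exact argument order of mm_fp4 and the autotune import path are assumptions, not lines copied from the test.

import torch
from flashinfer import mm_fp4
from flashinfer.autotuner import autotune  # assumed import path for the tuning context

def run_cudnn_fp4_gemm(a_fp4, b_fp4, a_descale, b_descale, alpha, auto_tuning=True):
    # Hypothetical call: inside autotune(), the cudnn runner profiles every
    # execution plan reported by get_valid_tactics() and caches the winner,
    # so later calls with the same shapes reuse the fastest plan.
    with autotune(auto_tuning):
        out = mm_fp4(
            a_fp4, b_fp4, a_descale, b_descale, alpha,
            torch.bfloat16,   # out_dtype; argument order is an assumption
            backend="cudnn",
        )
    return out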

0 commit comments
