Skip to content

Commit f61b1d4

Browse files
committed
Add third draft of mm_fp4 backend -- no autotune
1 parent 969f547 commit f61b1d4

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

benchmarks/routines/gemm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def testMmFp4(args):
790790
run_refcheck = args.refcheck
791791
use_128x4_sf_layout = args.use_128x4_sf_layout
792792
use_nvfp4 = args.use_nvfp4
793-
autotune_supported_backends = ["cutlass", "trtllm"]
793+
autotune_supported_backends = ["cutlass", "trtllm", "auto"]
794794
res = []
795795

796796
backends = filter_backends_by_compute_capability(backends, args.routine, device)

flashinfer/gemm.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,7 +1883,6 @@ def _auto_gemm_fp4_requirement(
18831883
checker, "is_compute_capability_supported"
18841884
) and checker.is_compute_capability_supported(cc_arch):
18851885
# At least one backend is supported
1886-
print(f"Backend {candidate} is supported on this device.")
18871886
return True
18881887

18891888
# No backend is supported on this device
@@ -1990,8 +1989,9 @@ def mm_fp4(
19901989
if backend == "auto":
19911990
cuda_major, _ = get_cuda_version(a.device)
19921991
cc_major, cc_minor = get_compute_capability(a.device)
1993-
# If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn.
1994-
if cuda_major >= 13: # to-do add cudnn version threshold
1992+
# If cuda version is 13 or greater:
1993+
# cudnn is more performant if cudnn version is 9.14 or greater.
1994+
if cuda_major >= 13 and cudnn.backend_version() >= 91400:
19951995
candidate_backends = ("cudnn", "cutlass")
19961996
# Otherwise, prioritize cutlass
19971997
else:
@@ -2022,11 +2022,7 @@ def mm_fp4(
20222022
supported_backends.append(candidate)
20232023
except Exception:
20242024
pass
2025-
print(f"Supported backends: {supported_backends}")
20262025
selected_backend = supported_backends[0]
2027-
print(
2028-
f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}"
2029-
)
20302026
else:
20312027
selected_backend = backend
20322028
if selected_backend == "cudnn":

0 commit comments

Comments
 (0)