Skip to content

Commit 80f76e0

Browse files
committed
Add third draft of mm_fp4 backend -- no audotune
1 parent 805487c commit 80f76e0

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

benchmarks/routines/gemm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def testMmFp4(args):
790790
run_refcheck = args.refcheck
791791
use_128x4_sf_layout = args.use_128x4_sf_layout
792792
use_nvfp4 = args.use_nvfp4
793-
autotune_supported_backends = ["cutlass", "trtllm"]
793+
autotune_supported_backends = ["cutlass", "trtllm", "auto"]
794794
res = []
795795

796796
backends = filter_backends_by_compute_capability(backends, args.routine, device)

flashinfer/gemm.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,7 +1887,6 @@ def _auto_gemm_fp4_requirement(
18871887
checker, "is_compute_capability_supported"
18881888
) and checker.is_compute_capability_supported(cc_arch):
18891889
# At least one backend is supported
1890-
print(f"Backend {candidate} is supported on this device.")
18911890
return True
18921891

18931892
# No backend is supported on this device
@@ -1994,8 +1993,9 @@ def mm_fp4(
19941993
if backend == "auto":
19951994
cuda_major, _ = get_cuda_version(a.device)
19961995
cc_major, cc_minor = get_compute_capability(a.device)
1997-
# If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn.
1998-
if cuda_major >= 13: # to-do add cudnn version threshold
1996+
# If cuda version is 13 or greater:
1997+
# cudnn is more performant if cudnn version is 9.14 or greater.
1998+
if cuda_major >= 13 and cudnn.backend_version() >= 91400:
19991999
candidate_backends = ("cudnn", "cutlass")
20002000
# Otherwise, prioritize cutlass
20012001
else:
@@ -2026,11 +2026,7 @@ def mm_fp4(
20262026
supported_backends.append(candidate)
20272027
except Exception:
20282028
pass
2029-
print(f"Supported backends: {supported_backends}")
20302029
selected_backend = supported_backends[0]
2031-
print(
2032-
f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}"
2033-
)
20342030
else:
20352031
selected_backend = backend
20362032
if selected_backend == "cudnn":

0 commit comments

Comments
 (0)