
Commit 969f547

Add second draft of mm_fp4 backend
1 parent 531967e commit 969f547

4 files changed: 105 additions & 25 deletions

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 3 additions & 3 deletions
@@ -238,9 +238,9 @@ def dtype_str_to_torch_dtype(dtype_str):
         "8.6": [],
         "8.9": [],
         "9.0": [],
-        "10.0": ["cudnn", "trtllm", "cutlass"],
-        "10.3": ["cudnn", "trtllm", "cutlass"],
-        "12.0": ["cudnn", "cutlass"],
+        "10.0": ["cudnn", "trtllm", "cutlass", "auto"],
+        "10.3": ["cudnn", "trtllm", "cutlass", "auto"],
+        "12.0": ["cudnn", "cutlass", "auto"],
     },
     # MOE
     "trtllm_fp4_block_scale_moe": {

benchmarks/routines/gemm.py

Lines changed: 10 additions & 3 deletions
@@ -131,7 +131,7 @@ def parse_gemm_args(line, parser):
         required=False,
         nargs="+",
         default=["cudnn"],
-        choices=["cudnn", "cublas", "trtllm", "cutlass"],
+        choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"],
         help="Kernel backends to test. Default: cudnn",
     )
     parser.add_argument(
@@ -823,7 +823,7 @@ def testMmFp4(args):
             print(
                 "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
             )
-            backends.remove("cutlass")
+            remove_cutlass = True
         if remove_cutlass:
             backends.remove("cutlass")
     if "cudnn" in backends:
@@ -833,6 +833,13 @@ def testMmFp4(args):
             remove_cudnn = True
         if remove_cudnn:
             backends.remove("cudnn")
+    if "auto" in backends:
+        remove_auto = False
+        if not use_128x4_sf_layout:
+            print("[INFO] auto backend does not support use_128x4_sf_layout=False")
+            remove_auto = True
+        if remove_auto:
+            backends.remove("auto")
     if getattr(args, "autotune", False):
         backends_to_remove = []
         for cur_backend in backends:
@@ -889,7 +896,7 @@ def testMmFp4(args):
     # res = torch.empty([m, n], device="cuda", dtype=res_dtype)
 
     def run_backend(backend):
-        if backend in ["cudnn", "trtllm", "cutlass"]:
+        if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
             return flashinfer.gemm.mm_fp4(
                 a=input_fp4,
                 b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
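The testMmFp4 hunks all use the same flag-then-remove pattern: each support check only sets a boolean, and the backend list is mutated in a single place afterwards. A small self-contained sketch of that pattern (placeholder values, not code from the commit):

# Placeholder inputs standing in for the benchmark's parsed arguments.
backends = ["cudnn", "cutlass", "auto"]
use_128x4_sf_layout = False

if "auto" in backends:
    remove_auto = False
    if not use_128x4_sf_layout:
        print("[INFO] auto backend does not support use_128x4_sf_layout=False")
        remove_auto = True
    if remove_auto:
        backends.remove("auto")

print(backends)  # ['cudnn', 'cutlass']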

flashinfer/gemm.py

Lines changed: 77 additions & 19 deletions
@@ -17,7 +17,7 @@
 import functools
 from enum import Enum
 from types import SimpleNamespace
-from typing import List, Literal, Optional, Tuple
+from typing import List, Literal, Optional, Tuple, cast
 
 from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm
 import torch
@@ -1703,7 +1703,7 @@ def _check_mm_fp4_problem_size(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     # Generic checks
@@ -1743,8 +1743,8 @@
 
     if backend != "trtllm" and use_8x4_sf_layout:
         raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
-    if backend != "cudnn" and not use_nvfp4:
-        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
+    if backend not in ["cudnn", "auto"] and not use_nvfp4:
+        raise ValueError("Only the cudnn and auto FP4 GEMM backends support mxfp4 quantization.")
 
     if use_nvfp4 and block_size != 16:
         raise ValueError("nvfp4 only supports block_size = 16.")
@@ -1765,7 +1765,7 @@ def _cudnn_gemm_fp4_requirement(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     if (
@@ -1823,7 +1823,7 @@ def _trtllm_gemm_fp4_requirement(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     if out_dtype != torch.bfloat16:
@@ -1845,17 +1845,57 @@ def _cutlass_gemm_fp4_requirement(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     return True
 
 
+@supported_compute_capability([100, 103, 110, 120])
+def _auto_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    use_nvfp4: bool = True,
+):
+    # The auto backend requires at least one concrete backend to be supported
+    # on the current device.
+    cc_major, cc_minor = get_compute_capability(a.device)
+    cc_arch = cc_major * 10 + cc_minor
+
+    # Check whether at least one backend supports this compute capability.
+    # trtllm is not considered here due to its different interface.
+    candidate_backends = ["cudnn", "cutlass"]
+    backend_checkers = {
+        "cudnn": _cudnn_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    }
+
+    for candidate in candidate_backends:
+        checker = backend_checkers[candidate]
+        if hasattr(
+            checker, "is_compute_capability_supported"
+        ) and checker.is_compute_capability_supported(cc_arch):
+            # At least one backend is supported
+            print(f"Backend {candidate} is supported on this device.")
+            return True
+
+    # No backend is supported on this device
+    return False
+
+
 @backend_requirement(
     {
         "cudnn": _cudnn_gemm_fp4_requirement,  # Each backend has its own requirement function
         "trtllm": _trtllm_gemm_fp4_requirement,
         "cutlass": _cutlass_gemm_fp4_requirement,
+        "auto": _auto_gemm_fp4_requirement,  # Requires at least one concrete backend supported on the current device
     },
     common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
 )
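To illustrate the capability check that _auto_gemm_fp4_requirement relies on, here is a standalone sketch, not flashinfer's actual implementation, of a decorator that records the supported compute capabilities and exposes is_compute_capability_supported on the wrapped requirement function:

from typing import Callable, List


def supported_compute_capability_sketch(archs: List[int]) -> Callable:
    # Toy stand-in: remember which SM versions a requirement function supports
    # and expose a query helper on the function object itself.
    def wrap(fn: Callable) -> Callable:
        fn.is_compute_capability_supported = lambda cc: cc in archs
        return fn
    return wrap


@supported_compute_capability_sketch([100, 103, 110, 120])
def _toy_requirement(*args, **kwargs) -> bool:
    return True


# The auto requirement only needs this attribute, not a full backend call:
print(_toy_requirement.is_compute_capability_supported(120))  # True
print(_toy_requirement.is_compute_capability_supported(90))   # False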
@@ -1950,22 +1990,40 @@ def mm_fp4(
     if backend == "auto":
         cuda_major, _ = get_cuda_version(a.device)
         cc_major, cc_minor = get_compute_capability(a.device)
-        cc_arch = cc_major * 10 + cc_minor
         # If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn.
         if cuda_major >= 13:  # to-do add cudnn version threshold
-            candidate_backends = ["cudnn", "cutlass"]
+            candidate_backends = ("cudnn", "cutlass")
         # Otherwise, prioritize cutlass
         else:
-            candidate_backends = ["cutlass", "cudnn"]
-
-        # Support check
-        backends_to_delete = []
-        for candidate_backend in candidate_backends:
-            if not mm_fp4.is_backend_supported(candidate_backend, cc_arch):
-                backends_to_delete.append(candidate_backend)
-        for backend_to_delete in backends_to_delete:
-            candidate_backends.remove(backend_to_delete)
-        selected_backend = candidate_backends[0]
+            candidate_backends = ("cutlass", "cudnn")
+
+        # Filter to only supported backends for this compute capability.
+        # Note: the requirement function already validated that at least one backend is supported.
+        supported_backends = []
+        for candidate in candidate_backends:
+            # mypy requires explicit type casting for the backend literal
+            backend_literal = cast(
+                Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
+            )
+            try:
+                _check_mm_fp4_problem_size(
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out_dtype,
+                    out,
+                    block_size,
+                    use_8x4_sf_layout,
+                    backend_literal,
+                    use_nvfp4,
+                )
+                supported_backends.append(candidate)
+            except Exception:
+                pass
+        print(f"Supported backends: {supported_backends}")
+        selected_backend = supported_backends[0]
         print(
             f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}"
         )
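For reference, a minimal usage sketch of the new default path (not part of the commit). The keyword names mirror the _check_mm_fp4_problem_size signature above; the FP4-quantized operands and their block scale factors are assumed to be prepared by the caller, the way the benchmark builds input_fp4 and mat2_fp4:

import torch
import flashinfer


def fp4_gemm_auto(a_fp4, b_fp4, a_scale, b_scale, alpha=None):
    # backend="auto" resolves to cudnn or cutlass: cudnn is preferred on
    # CUDA 13+, cutlass otherwise, keeping only candidates whose problem-size
    # checks pass on the current device.
    return flashinfer.gemm.mm_fp4(
        a=a_fp4,
        b=b_fp4,
        a_descale=a_scale,
        b_descale=b_scale,
        alpha=alpha,
        out_dtype=torch.bfloat16,
        backend="auto",
    )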

tests/gemm/test_mm_fp4.py

Lines changed: 15 additions & 0 deletions
@@ -105,5 +105,20 @@ def test_mm_fp4(
         pytest.fail(str(e))
 
 
+# Split tests for checking auto functionality
+@pytest.mark.parametrize("m", [1, 48, 256, 512])
+@pytest.mark.parametrize("n", [256, 512])
+@pytest.mark.parametrize("k", [256, 512])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("backend", ["auto"])
+@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
+@pytest.mark.parametrize("auto_tuning", [False, True])
+@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
+def test_mm_fp4_backend_auto(
+    m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+):
+    test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
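To run only the new auto-backend parametrizations, the file's existing pytest.main entry point can be narrowed with a -k filter, for example (test path taken from the commit's file list):

import pytest

# Select just the auto-backend tests added above.
pytest.main(["-k", "test_mm_fp4_backend_auto", "tests/gemm/test_mm_fp4.py"])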
