@@ -17,7 +17,7 @@
 import functools
 from enum import Enum
 from types import SimpleNamespace
-from typing import List, Literal, Optional, Tuple
+from typing import List, Literal, Optional, Tuple, cast
 
 from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm
 import torch
@@ -1703,7 +1703,7 @@ def _check_mm_fp4_problem_size(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     # Generic checks
@@ -1743,8 +1743,8 @@ def _check_mm_fp4_problem_size(
 
     if backend != "trtllm" and use_8x4_sf_layout:
         raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
-    if backend != "cudnn" and not use_nvfp4:
-        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
+    if backend not in ["cudnn", "auto"] and not use_nvfp4:
+        raise ValueError("Only cudnn and auto FP4 GEMM support mxfp4 quantization.")
 
     if use_nvfp4 and block_size != 16:
         raise ValueError("nvfp4 only supports block_size = 16.")
@@ -1765,7 +1765,7 @@ def _cudnn_gemm_fp4_requirement(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     if (
@@ -1823,7 +1823,7 @@ def _trtllm_gemm_fp4_requirement(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     if out_dtype != torch.bfloat16:
@@ -1845,17 +1845,57 @@ def _cutlass_gemm_fp4_requirement(
     out: Optional[torch.Tensor] = None,
     block_size: int = 16,
     use_8x4_sf_layout: bool = False,
-    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
     use_nvfp4: bool = True,
 ):
     return True
 
 
+@supported_compute_capability([100, 103, 110, 120])
+def _auto_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass", "auto"] = "auto",
+    use_nvfp4: bool = True,
+):
+    # The auto backend requires at least one backend to be supported on the current device.
+    cc_major, cc_minor = get_compute_capability(a.device)
+    cc_arch = cc_major * 10 + cc_minor
+
+    # Check if at least one backend is supported for this compute capability.
+    backend_checkers = {
+        "cudnn": _cudnn_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+        # trtllm is not considered due to its different interface.
+    }
+
+    for candidate, checker in backend_checkers.items():
+        if hasattr(
+            checker, "is_compute_capability_supported"
+        ) and checker.is_compute_capability_supported(cc_arch):
+            # At least one backend is supported.
+            print(f"Backend {candidate} is supported on this device.")
+            return True
+
+    # No backend is supported on this device.
+    return False
+
+
 @backend_requirement(
     {
         "cudnn": _cudnn_gemm_fp4_requirement,  # Each backend has its own requirement function
         "trtllm": _trtllm_gemm_fp4_requirement,
         "cutlass": _cutlass_gemm_fp4_requirement,
+        "auto": _auto_gemm_fp4_requirement,  # Auto requires at least one backend to be supported on the current device
     },
     common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
 )
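
Aside (not part of the diff): _auto_gemm_fp4_requirement probes each checker for an is_compute_capability_supported attribute, which the supported_compute_capability decorator presumably attaches. A minimal sketch of such a decorator, assuming it only records the allowed SM architectures; flashinfer's actual implementation may differ:

    from typing import Callable, List

    def supported_compute_capability(archs: List[int]) -> Callable:
        # Hypothetical re-implementation for illustration: attach a predicate
        # so callers can probe support, e.g. archs=[100, 103, 110, 120].
        def decorate(fn: Callable) -> Callable:
            fn.is_compute_capability_supported = lambda cc_arch: cc_arch in archs
            return fn
        return decorate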
@@ -1950,22 +1990,40 @@ def mm_fp4(
     if backend == "auto":
         cuda_major, _ = get_cuda_version(a.device)
         cc_major, cc_minor = get_compute_capability(a.device)
-        cc_arch = cc_major * 10 + cc_minor
         # If cuda version is 13 or greater AND cudnn version is 9.X or greater, prioritize cudnn.
         if cuda_major >= 13:  # to-do add cudnn version threshold
-            candidate_backends = ["cudnn", "cutlass"]
+            candidate_backends = ("cudnn", "cutlass")
         # Otherwise, prioritize cutlass
         else:
-            candidate_backends = ["cutlass", "cudnn"]
-
-        # Support check
-        backends_to_delete = []
-        for candidate_backend in candidate_backends:
-            if not mm_fp4.is_backend_supported(candidate_backend, cc_arch):
-                backends_to_delete.append(candidate_backend)
-        for backend_to_delete in backends_to_delete:
-            candidate_backends.remove(backend_to_delete)
-        selected_backend = candidate_backends[0]
+            candidate_backends = ("cutlass", "cudnn")
+
+        # Filter to only supported backends for this compute capability.
+        # Note: the requirement function already validated that at least one backend is supported.
+        supported_backends = []
+        for candidate in candidate_backends:
+            # mypy requires an explicit cast for the backend literal
+            backend_literal = cast(
+                Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
+            )
+            try:
+                _check_mm_fp4_problem_size(
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out_dtype,
+                    out,
+                    block_size,
+                    use_8x4_sf_layout,
+                    backend_literal,
+                    use_nvfp4,
+                )
+                supported_backends.append(candidate)
+            except Exception:
+                pass
+        print(f"Supported backends: {supported_backends}")
+        selected_backend = supported_backends[0]
         print(
             f"Selected backend: {selected_backend} for cuda version {cuda_major} and compute capability {cc_major}{cc_minor}"
         )
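
Usage sketch (not part of the diff): with the new default, callers can omit the backend argument and let mm_fp4 pick cudnn or cutlass. The import path and the quantized inputs below are assumptions; only the parameter names come from this change:

    import torch
    from flashinfer import mm_fp4  # assumed export path

    # a_fp4, b_fp4 and their block-scale factors are placeholders for the
    # output of an FP4 quantization step not shown in this diff.
    out = mm_fp4(
        a_fp4,
        b_fp4,
        a_descale,
        b_descale,
        alpha=None,
        out_dtype=torch.bfloat16,
        block_size=16,   # nvfp4 requires block_size == 16 per the common check
        use_nvfp4=True,
        backend="auto",  # selects cudnn or cutlass based on CUDA version and compute capability
    )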