diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 520a3e1c6f..1c041be9d0 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -95,23 +95,10 @@ class ArtifactPath: "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" ) CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/" + # For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/" -@dataclass(frozen=True) -class MetaInfoHash: - DEEPGEMM: str = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb" - TRTLLM_GEN_FMHA: str = ( - "2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a" - ) - TRTLLM_GEN_BMM: str = ( - "26c51b75921be90235d193675facdea5d8341c4c52c73bd0a7c8e787c0388beb" - ) - TRTLLM_GEN_GEMM: str = ( - "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" - ) - - class CheckSumHash: """ This class is used to store the checksums of the cubin files in artifactory. diff --git a/flashinfer/deep_gemm.py b/flashinfer/deep_gemm.py index 4da91750fd..c7e42494d4 100644 --- a/flashinfer/deep_gemm.py +++ b/flashinfer/deep_gemm.py @@ -41,7 +41,7 @@ import torch -from .artifacts import ArtifactPath, MetaInfoHash +from .artifacts import ArtifactPath from .cuda_utils import checkCudaErrors from .jit.cubin_loader import get_cubin from .jit.env import FLASHINFER_CUBIN_DIR @@ -1487,13 +1487,15 @@ def m_grouped_fp8_gemm_nt_masked( class KernelMap: - def __init__(self, sha256: str): - self.sha256 = sha256 + # Hash for kernel_map.json, updated when deepgemm cubins are republished + KERNEL_MAP_HASH = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb" + + def __init__(self): self.indice = None def init_indices(self): indice_path = ArtifactPath.DEEPGEMM + "/" + "kernel_map.json" - assert get_cubin(indice_path, self.sha256), ( + assert get_cubin(indice_path, self.KERNEL_MAP_HASH), ( "cubin kernel map file not found, nor downloaded with matched sha256" ) path = FLASHINFER_CUBIN_DIR / indice_path @@ -1513,4 +1515,4 @@ def __getitem__(self, key): return self.indice[key] -KERNEL_MAP = KernelMap(MetaInfoHash.DEEPGEMM) +KERNEL_MAP = KernelMap()