Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,10 @@ class ArtifactPath:
"1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
)
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
# For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py
DEEPGEMM: str = "a72d85b019dc125b9f711300cb989430f762f5a6/deep-gemm/"


@dataclass(frozen=True)
class MetaInfoHash:
DEEPGEMM: str = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb"
TRTLLM_GEN_FMHA: str = (
"2b8a485f2af84768bc769e678eb6014a8181ad95a7ea9e699de5efca4b18ec6a"
)
TRTLLM_GEN_BMM: str = (
"26c51b75921be90235d193675facdea5d8341c4c52c73bd0a7c8e787c0388beb"
)
TRTLLM_GEN_GEMM: str = (
"bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340"
)


class CheckSumHash:
"""
This class is used to store the checksums of the cubin files in artifactory.
Expand Down
12 changes: 7 additions & 5 deletions flashinfer/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

import torch

from .artifacts import ArtifactPath, MetaInfoHash
from .artifacts import ArtifactPath
from .cuda_utils import checkCudaErrors
from .jit.cubin_loader import get_cubin
from .jit.env import FLASHINFER_CUBIN_DIR
Expand Down Expand Up @@ -1487,13 +1487,15 @@ def m_grouped_fp8_gemm_nt_masked(


class KernelMap:
def __init__(self, sha256: str):
self.sha256 = sha256
# Hash for kernel_map.json, updated when deepgemm cubins are republished
KERNEL_MAP_HASH = "f161e031826adb8c4f0d31ddbd2ed77e4909e4e43cdfc9728918162a62fcccfb"

def __init__(self):
self.indice = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and consistency with the filename `kernel_map.json`, consider renaming the instance variable `indice` to `kernel_map`, which makes it clearer what the variable holds. The rename would need to be applied to every use of `self.indice` within the `KernelMap` class.

Suggested change
- self.indice = None
+ self.kernel_map = None


def init_indices(self):
indice_path = ArtifactPath.DEEPGEMM + "/" + "kernel_map.json"
assert get_cubin(indice_path, self.sha256), (
assert get_cubin(indice_path, self.KERNEL_MAP_HASH), (
"cubin kernel map file not found, nor downloaded with matched sha256"
)
path = FLASHINFER_CUBIN_DIR / indice_path
Expand All @@ -1513,4 +1515,4 @@ def __getitem__(self, key):
return self.indice[key]


KERNEL_MAP = KernelMap(MetaInfoHash.DEEPGEMM)
KERNEL_MAP = KernelMap()