From 737dbba82f49f9f6c2d3a18a6a69d342cfca02e1 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 30 Nov 2025 17:49:25 -0500 Subject: [PATCH 1/2] update Signed-off-by: Ubospica --- examples/win_at_p.py | 1 - flashinfer_bench/apply/key.py | 2 +- flashinfer_bench/compile/__init__.py | 11 +- flashinfer_bench/compile/builder.py | 80 +----- flashinfer_bench/compile/builders/__init__.py | 4 +- .../compile/builders/cuda_builder.py | 239 ------------------ .../compile/builders/python_builder.py | 121 +++++---- .../compile/builders/torch_builder.py | 208 +++++++++++++++ .../compile/builders/triton_builder.py | 57 ++--- .../compile/builders/tvm_ffi_builder.py | 217 +++++----------- flashinfer_bench/compile/registry.py | 64 +++-- flashinfer_bench/compile/runnable.py | 145 ++++++++--- flashinfer_bench/compile/utils.py | 68 +++++ flashinfer_bench/data/solution.py | 49 +++- flashinfer_bench/data/trace_set.py | 1 - tests/apply/test_runtime.py | 6 +- tests/compile/test_builder.py | 12 +- tests/compile/test_python_builder.py | 7 +- tests/compile/test_runnable.py | 14 +- tests/compile/test_triton_builder.py | 6 +- tests/compile/test_tvm_ffi_builder.py | 20 +- tests/data/test_solution.py | 2 - 22 files changed, 663 insertions(+), 671 deletions(-) delete mode 100644 flashinfer_bench/compile/builders/cuda_builder.py create mode 100644 flashinfer_bench/compile/builders/torch_builder.py create mode 100644 flashinfer_bench/compile/utils.py diff --git a/examples/win_at_p.py b/examples/win_at_p.py index a65fda87..474153f1 100644 --- a/examples/win_at_p.py +++ b/examples/win_at_p.py @@ -26,7 +26,6 @@ Notes: - If multiple runs exist for the same author within a group, we take the MIN latency for that author in that group. - By default, the baseline author ('flashinfer') is EXCLUDED from output curves; use --include-baseline to include it. - """ import argparse diff --git a/flashinfer_bench/apply/key.py b/flashinfer_bench/apply/key.py index 85a53325..4ed8f569 100644 --- a/flashinfer_bench/apply/key.py +++ b/flashinfer_bench/apply/key.py @@ -12,7 +12,7 @@ class ApplyKey: axes: Tuple[Tuple[str, int], ...] = field(default_factory=tuple) # Features extracted from input tensors - feats: Tuple[Tuple[str, Union[int, Union[float, bool]]], ...] = field(default_factory=tuple) + feats: Tuple[Tuple[str, Union[int, float, bool]], ...] = field(default_factory=tuple) def encode(self) -> str: return json.dumps( diff --git a/flashinfer_bench/compile/__init__.py b/flashinfer_bench/compile/__init__.py index 03284bc8..1e0a4801 100644 --- a/flashinfer_bench/compile/__init__.py +++ b/flashinfer_bench/compile/__init__.py @@ -5,6 +5,13 @@ from .builder import Builder, BuildError from .registry import BuilderRegistry, get_builder_registry -from .runnable import Runnable +from .runnable import Runnable, RunnableMetadata -__all__ = ["Builder", "BuildError", "BuilderRegistry", "Runnable", "get_builder_registry"] +__all__ = [ + "Builder", + "BuildError", + "BuilderRegistry", + "Runnable", + "RunnableMetadata", + "get_builder_registry", +] diff --git a/flashinfer_bench/compile/builder.py b/flashinfer_bench/compile/builder.py index 5f4ebc40..37cc967c 100644 --- a/flashinfer_bench/compile/builder.py +++ b/flashinfer_bench/compile/builder.py @@ -1,48 +1,11 @@ from __future__ import annotations -import hashlib -import os -import re -import tempfile from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional +from typing import Dict -from flashinfer_bench.compile.runnable import Runnable -from flashinfer_bench.data import Definition, Solution, SourceFile +from flashinfer_bench.data import Definition, Solution - -def write_sources_to_dir(dir: str, sources: list[SourceFile]) -> None: - os.makedirs(dir, exist_ok=True) - for src in sources: - abspath = os.path.join(dir, src.path) - os.makedirs(os.path.dirname(abspath), exist_ok=True) - with open(abspath, "w", encoding="utf-8") as f: - f.write(src.content) - - -def write_sources_to_temp(base: str, sources: list[SourceFile], pkg: Optional[str] = None) -> str: - os.makedirs(base, exist_ok=True) - tmpdir = tempfile.mkdtemp(dir=base) - if pkg: - tmpdir = os.path.join(tmpdir, pkg) - os.makedirs(tmpdir, exist_ok=True) - write_sources_to_dir(tmpdir, sources) - return tmpdir - - -def create_pkg_name(sol: Solution, prefix: str = "") -> str: - # Normalize the solution name - s = re.sub(r"[^0-9a-zA-Z_]", "_", sol.name) - if not s or s[0].isdigit(): - s = "_" + s - - # Hash the sources - h = hashlib.sha1() - for src in sol.sources: - h.update(src.path.encode()) - h.update(src.content.encode()) - - return prefix + s + "_" + h.hexdigest()[:6] +from .runnable import Runnable class BuildError(RuntimeError): @@ -52,8 +15,10 @@ class BuildError(RuntimeError): class Builder(ABC): """Builder abstraction: (Definition, Solution) -> Runnable with hidden cache.""" + @abstractmethod def __init__(self) -> None: - self._cache: Dict[str, Runnable] = {} + """Constructor for the Builder class.""" + ... @abstractmethod def can_build(self, solution: Solution) -> bool: @@ -61,35 +26,12 @@ def can_build(self, solution: Solution) -> bool: ... @abstractmethod - def _build(self, definition: Definition, solution: Solution) -> Runnable: - """Perform a real build and return a Runnable; raise BuildError on failure.""" + def get_key(self, solution: Solution) -> str: + """Get a unique key for a solution. This key is used to cache the build results and + used as the module name.""" ... @abstractmethod - def _make_closer(self, *args, **kwargs) -> Callable[[], None]: - """Factory for a resource closer used by the concrete builder.""" - ... - - @abstractmethod - def _make_key(self, solution: Solution) -> str: - """Cache key for a solution.""" - ... - def build(self, definition: Definition, solution: Solution) -> Runnable: - """Public entry with per-solution cache keyed by solution.name.""" - key = self._make_key(solution) - if key in self._cache: - return self._cache[key] - runnable = self._build(definition, solution) - self._cache[key] = runnable - return runnable - - def clear_cache(self) -> None: - """Close all cached runnables and clear the cache.""" - for r in list(self._cache.values()): - try: - r.close() - except Exception: - # Best-effort cleanup; keep going - pass - self._cache.clear() + """Perform a real build and return a Runnable; raise BuildError on failure.""" + ... diff --git a/flashinfer_bench/compile/builders/__init__.py b/flashinfer_bench/compile/builders/__init__.py index bd4cd21b..5932dffb 100644 --- a/flashinfer_bench/compile/builders/__init__.py +++ b/flashinfer_bench/compile/builders/__init__.py @@ -1,6 +1,6 @@ -from .cuda_builder import CUDABuilder from .python_builder import PythonBuilder +from .torch_builder import TorchBuilder from .triton_builder import TritonBuilder from .tvm_ffi_builder import TVMFFIBuilder -__all__ = ["CUDABuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"] +__all__ = ["TorchBuilder", "PythonBuilder", "TritonBuilder", "TVMFFIBuilder"] diff --git a/flashinfer_bench/compile/builders/cuda_builder.py b/flashinfer_bench/compile/builders/cuda_builder.py deleted file mode 100644 index 84bc0dab..00000000 --- a/flashinfer_bench/compile/builders/cuda_builder.py +++ /dev/null @@ -1,239 +0,0 @@ -from __future__ import annotations - -import logging -import os -import re -import shutil -import sys -from importlib import resources -from pathlib import Path -from typing import Dict, List, Optional - -from flashinfer_bench.compile.builder import ( - Builder, - BuildError, - create_pkg_name, - write_sources_to_dir, -) -from flashinfer_bench.compile.runnable import Runnable -from flashinfer_bench.data import Definition, Solution, SourceFile, SupportedLanguages -from flashinfer_bench.utils import is_cuda_available - -CUDA_ALLOWED_EXTS = [".cu", ".cpp", ".cc", ".cxx", ".c"] - -logger = logging.getLogger(__name__) - - -def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None): - include_path = None - ldflags = [] - - try: - include_dir = resources.files(pkg_name) / "include" - if include_dir.exists(): - include_path = str(include_dir) - - if lib_names: - lib_dir = resources.files(pkg_name) / "lib" - if lib_dir.exists(): - lib_path = Path(lib_dir) - - if sys.platform.startswith("linux"): - ldflags = [f"-L{lib_path}", f"-Wl,-rpath,{lib_path}"] - - for lib_name in lib_names: - # Look for unversioned .so first - lib_file = lib_path / f"lib{lib_name}.so" - if lib_file.exists(): - ldflags.append(f"-l{lib_name}") - else: - # Find versioned .so files - versioned = sorted(lib_path.glob(f"lib{lib_name}.so.*")) - if versioned: - ldflags.append(f"-l:{versioned[-1].name}") - else: - ldflags.append(f"-l{lib_name}") # Fallback - - elif sys.platform == "win32": - ldflags = [f"/LIBPATH:{lib_path}"] + lib_names - - except Exception: - logger.warning( - "Failed to discover resources for CUDA package '%s'; continuing without it.", - pkg_name, - exc_info=True, - ) - - return include_path, ldflags - - -CUDA_DEPS = { - "cublas": ("nvidia.cublas", ["cublas", "cublasLt"]), - "cudnn": ("nvidia.cudnn", ["cudnn"]), - "cutlass": ("flashinfer_bench._deps.cutlass", None), # Header-only -} - - -def _discover_cuda_deps(extra_include_paths: Dict[str, str], extra_ldflags: Dict[str, List[str]]): - for dep_name, (pkg_name, libs) in CUDA_DEPS.items(): - include_path, ldflags = _get_package_paths(pkg_name, libs) - if include_path: - extra_include_paths[dep_name] = include_path - if ldflags: - extra_ldflags[dep_name] = ldflags - - -CUDA_DEPS_INCLUDE_PATTERNS = { - "cublas": re.compile( - r'^\s*#\s*include\s*[<"]\s*(?:cublas|cublasLt)', re.MULTILINE | re.IGNORECASE - ), - "cudnn": re.compile(r'^\s*#\s*include\s*[<"]\s*cudnn', re.MULTILINE | re.IGNORECASE), - "cutlass": re.compile(r'^\s*#\s*include\s*[<"]\s*cutlass/', re.MULTILINE), -} - - -def _check_dependency(sources: List[SourceFile], dep_name: str) -> bool: - pattern = CUDA_DEPS_INCLUDE_PATTERNS.get(dep_name) - if not pattern: - return False - - for source in sources: - if not isinstance(source.content, str): - continue - - # Fast skip - if dep_name not in source.content.lower(): - continue - - # Remove comments - content = source.content - content = re.sub(r"//.*?$", "", content, flags=re.MULTILINE) - content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL) - - if pattern.search(content): - return True - - return False - - -class CUDABuilder(Builder): - _cuda_available: bool = None - - @classmethod - def _get_cuda_available(cls) -> bool: - if cls._cuda_available is None: - cls._cuda_available = is_cuda_available() - return cls._cuda_available - - def __init__(self) -> None: - super().__init__() - self._build_dirs: Dict[str, str] = {} - self._extra_include_paths: Dict[str, str] = {} - self._extra_ldflags: Dict[str, List[str]] = {} - _discover_cuda_deps(self._extra_include_paths, self._extra_ldflags) - - def can_build(self, sol: Solution) -> bool: - return sol.spec.language == SupportedLanguages.CUDA and self._get_cuda_available() - - def _make_key(self, solution: Solution) -> str: - return f"cuda::{create_pkg_name(solution)}" - - def _make_closer(self): - # We keep build dirs for torch extension caching. The temp dirs can be cleaned by - # calling `clear_cache` on program exit. - return lambda: None - - def _build(self, defn: Definition, sol: Solution) -> Runnable: - # CUDA solutions must provide a C/CUDA symbol as entry point. - # If user prefer a Python wrapper, set language to `python` and ensure compilation and - # binding are properly handled. - entry_file_extension = "." + sol.spec.entry_point.split("::")[0].split(".")[-1] - if entry_file_extension not in CUDA_ALLOWED_EXTS: - raise BuildError( - f"Entry file type not recognized. Must be one of {CUDA_ALLOWED_EXTS}, " - f"got {entry_file_extension}." - ) - - if not self._get_cuda_available(): - raise BuildError("torch.cuda is not available in the current environment") - - from torch.utils.cpp_extension import load - - symbol = sol.spec.entry_point.split("::")[-1] - name = create_pkg_name(sol, "fib_cuda_") - cache_root = os.environ.get( - "FIB_CACHE_PATH", os.path.join(os.path.expanduser("~"), ".cache", "flashinfer_bench") - ) - build_dir = os.path.join(cache_root, "cuda", name) - write_sources_to_dir(build_dir, sol.sources) - self._build_dirs[name] = build_dir - - sources = [s for s in sol.sources if s.path.endswith(tuple(CUDA_ALLOWED_EXTS))] - - has_cuda_sources = any(s.path.endswith(".cu") for s in sources) - if not has_cuda_sources: - raise BuildError("No CUDA sources provided for CUDA build") - - src_paths = [os.path.join(build_dir, s.path) for s in sources] - - extra_include_paths = [build_dir] - extra_ldflags = [] - - for dep in CUDA_DEPS.keys(): - if _check_dependency(sources, dep): - inc_path = self._extra_include_paths.get(dep) - if not inc_path: - raise BuildError( - f"{dep} is not available in the current environment but referenced " - f"by {sol.name}" - ) - extra_include_paths.append(inc_path) - ldflags = self._extra_ldflags.get(dep) - if ldflags: - extra_ldflags.extend(ldflags) - - closer = self._make_closer() - - try: - ext = load( - name=name, - sources=src_paths, - extra_include_paths=extra_include_paths, - extra_ldflags=extra_ldflags, - with_cuda=True, - build_directory=build_dir, - verbose=True, - ) - except Exception as e: - raise BuildError(f"CUDA build failed for solution '{sol.name}': {e}") from e - - try: - fn = getattr(ext, symbol) - except AttributeError as e: - raise BuildError(f"Exported symbol '{symbol}' not found in built extension") from e - - arg_order = list(defn.inputs.keys()) - - def _kw_adapter(**kwargs): - args = [kwargs[name] for name in arg_order] - return fn(*args) - - meta = { - "definition": defn.name, - "solution": sol.name, - "language": "cuda", - "name": name, - "entry": sol.spec.entry_point, - "symbol": symbol, - "build_dir": build_dir, - "binary": getattr(ext, "__file__", None), - "extra_include_paths": extra_include_paths, - "extra_ldflags": extra_ldflags, - } - return Runnable(fn=_kw_adapter, closer=closer, meta=meta) - - def clear_cache(self) -> None: - super().clear_cache() - for build_dir in self._build_dirs.values(): - shutil.rmtree(build_dir, ignore_errors=True) - self._build_dirs.clear() diff --git a/flashinfer_bench/compile/builders/python_builder.py b/flashinfer_bench/compile/builders/python_builder.py index 86e4e83e..f845b52e 100644 --- a/flashinfer_bench/compile/builders/python_builder.py +++ b/flashinfer_bench/compile/builders/python_builder.py @@ -5,100 +5,95 @@ import shutil import sys from pathlib import Path -from typing import Any, Callable - -from flashinfer_bench.compile.builder import ( - Builder, - BuildError, - create_pkg_name, - write_sources_to_temp, -) -from flashinfer_bench.compile.runnable import Runnable +from typing import Any, Callable, ClassVar + +from flashinfer_bench.compile.builder import Builder, BuildError +from flashinfer_bench.compile.runnable import Runnable, RunnableMetadata +from flashinfer_bench.compile.utils import create_package_name, write_sources_to_path from flashinfer_bench.data import Definition, Solution, SupportedLanguages +from flashinfer_bench.env import get_fib_cache_path class PythonBuilder(Builder): """Load a Python entry point from provided sources into a temporary module.""" + _BUILD_DIR_NAME: ClassVar[str] = "python" + """Subdirectory under FIB_CACHE_PATH where build results are stored.""" + + _KEY_PREFIX: ClassVar[str] = "fib_python_" + """Prefix for cache keys to avoid collisions with other builders. fib_ prefix is added + to avoid name collision in python imports.""" + def can_build(self, sol: Solution) -> bool: return sol.spec.language == SupportedLanguages.PYTHON - def _make_key(self, solution: Solution) -> str: - return f"python::{create_pkg_name(solution)}" + def get_key(self, solution: Solution) -> str: + return create_package_name(solution, self._KEY_PREFIX) + + def _get_build_path(self, key: str) -> Path: + return get_fib_cache_path() / self._BUILD_DIR_NAME / key - def _make_closer(self, pkg: str, tmpdir: str) -> Callable[[], None]: - def closer() -> None: + def _get_cleaner(self, package: str, build_path: Path) -> Callable[[], None]: + def cleaner() -> None: try: # Unload module and submodules - to_delete = [m for m in list(sys.modules) if m == pkg or m.startswith(pkg + ".")] + to_delete = [m for m in sys.modules if m == package or m.startswith(package + ".")] for m in to_delete: sys.modules.pop(m, None) except Exception: pass + try: - while tmpdir in sys.path: - try: - sys.path.remove(tmpdir) - except ValueError: - break + build_path_str = str(build_path) + if build_path_str in sys.path: + sys.path.remove(build_path_str) finally: - shutil.rmtree(tmpdir, ignore_errors=True) + shutil.rmtree(build_path, ignore_errors=True) - return closer + return cleaner - def _build(self, defn: Definition, sol: Solution) -> Runnable: - entry = sol.spec.entry_point - try: - entry_file, entry_func = entry.split("::", 1) - except ValueError as e: - raise BuildError("entry_point must be '::' for Python") from e - - # _fib_py_some_solution_ - pkg = create_pkg_name(sol, "fib_py_") - # _fib_py_some_solution_.entry_file - module_name = pkg + "." + ".".join(Path(entry_file).with_suffix("").parts) - # $HOME/.cache/flashinfer_bench/python// - cache_root = os.environ.get( - "FIB_CACHE_PATH", os.path.join(os.path.expanduser("~"), ".cache", "flashinfer_bench") - ) - pkg_dir = write_sources_to_temp( - base=os.path.join(cache_root, "python"), sources=sol.sources, pkg=pkg - ) - tmp_root = os.path.dirname(pkg_dir) - closer = self._make_closer(pkg, tmp_root) + def build(self, definition: Definition, solution: Solution) -> Runnable: + entry_file = solution.get_entry_path() + entry_symbol = solution.get_entry_symbol() - # Insert tmp_root into sys.path for import resolution - sys.path.insert(0, tmp_root) + if entry_file.suffix != ".py": + raise BuildError(f"Entry file '{entry_file}' is not a Python file") - if not os.path.exists(os.path.join(pkg_dir, *Path(entry_file).parts)): - closer() - raise BuildError(f"Entry file '{entry_file}' not found under tmp_root: {tmp_root}") + # fib_python_some_solution_ + package_name = self.get_key(solution) + # fib_python_some_solution_.entry_file + module_name = package_name + "." + ".".join(Path(entry_file).with_suffix("").parts) + + build_path = self._get_build_path(package_name) + write_sources_to_path(build_path, solution.sources) + cleaner = self._get_cleaner(package_name, build_path) + + # Insert tmp_root into sys.path for import resolution + sys.path.insert(0, str(build_path)) try: mod = importlib.import_module(module_name) except Exception as e: - closer() + cleaner() raise BuildError(f"Failed importing module '{module_name}' from sources: {e}") from e try: - fn: Any = getattr(mod, entry_func) + fn: Any = getattr(mod, entry_symbol) except AttributeError as e: - closer() + cleaner() raise BuildError( - f"Entry function '{entry_func}' not found in module '{module_name}'" + f"Entry symbol '{entry_symbol}' not found in module '{module_name}'" ) from e if not callable(fn): - closer() - raise BuildError(f"Entry '{entry_func}' is not callable") - - meta = { - "definition": defn.name, - "solution": sol.name, - "language": "python", - "module": module_name, - "entry": entry, - "temp_dir": tmp_root, - } - - return Runnable(fn=fn, closer=closer, meta=meta) + cleaner() + raise BuildError(f"Entry symbol '{entry_symbol}' is not callable") + + metadata = RunnableMetadata( + build_type="python", + definition=definition.name, + solution=solution.name, + misc={"module": module_name, "entry_symbol": entry_symbol}, + ) + + return Runnable(callable=fn, metadata=metadata, cleaner=cleaner) diff --git a/flashinfer_bench/compile/builders/torch_builder.py b/flashinfer_bench/compile/builders/torch_builder.py new file mode 100644 index 00000000..4d8ce696 --- /dev/null +++ b/flashinfer_bench/compile/builders/torch_builder.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +import logging +import shutil +import sys +from importlib import resources +from pathlib import Path +from typing import Callable, ClassVar, Dict, List, Optional, Tuple + +from flashinfer_bench.compile import Runnable, RunnableMetadata +from flashinfer_bench.compile.builder import Builder, BuildError +from flashinfer_bench.compile.utils import create_package_name, write_sources_to_path +from flashinfer_bench.data import Definition, Solution, SourceFile, SupportedLanguages +from flashinfer_bench.env import get_fib_cache_path + +logger = logging.getLogger(__name__) + +# C/C++ and CUDA source file extensions +_CPP_CUDA_EXTENSIONS: List[str] = [".cu", ".cpp", ".cc", ".cxx", ".c"] + +# CUDA dependencies and their package names and library names +_CUDA_DEPS: Dict[str, Tuple[str, Optional[List[str]]]] = { + "cublas": ("nvidia.cublas", ["cublas", "cublasLt"]), + "cudnn": ("nvidia.cudnn", ["cudnn"]), + "cutlass": ("flashinfer_bench._deps.cutlass", None), # Header-only dependency +} + + +def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None): + include_path = None + ldflags = [] + + try: + include_dir = resources.files(pkg_name) / "include" + if include_dir.exists(): + include_path = str(include_dir) + + if lib_names: + lib_dir = resources.files(pkg_name) / "lib" + if lib_dir.exists(): + lib_path = Path(lib_dir) + + if sys.platform.startswith("linux"): + ldflags = [f"-L{lib_path}", f"-Wl,-rpath,{lib_path}"] + + for lib_name in lib_names: + # Look for unversioned .so first + lib_file = lib_path / f"lib{lib_name}.so" + if lib_file.exists(): + ldflags.append(f"-l{lib_name}") + else: + # Find versioned .so files + versioned = sorted(lib_path.glob(f"lib{lib_name}.so.*")) + if versioned: + ldflags.append(f"-l:{versioned[-1].name}") + else: + ldflags.append(f"-l{lib_name}") # Fallback + + elif sys.platform == "win32": + ldflags = [f"/LIBPATH:{lib_path}"] + lib_names + + except Exception: + logger.warning( + "Failed to discover resources for CUDA package '%s'; continuing without it.", + pkg_name, + exc_info=True, + ) + + return include_path, ldflags + + +class TorchBuilder(Builder): + _BUILD_DIR_NAME: ClassVar[str] = "torch" + """Subdirectory under FIB_CACHE_PATH where build results are stored""" + + _KEY_PREFIX: ClassVar[str] = "torch_" + """Prefix for cache keys to avoid collisions with other builders""" + + _extra_include_paths: Dict[str, str] + """Extra include paths for CUDA dependencies""" + _extra_ldflags: Dict[str, List[str]] + """Extra link flags for CUDA dependencies""" + + def __init__(self) -> None: + super().__init__() + self._discover_cuda_deps() + + def _discover_cuda_deps(self): + self._extra_include_paths = {} + self._extra_ldflags = {} + for dep_name, (pkg_name, libs) in _CUDA_DEPS.items(): + include_path, ldflags = _get_package_paths(pkg_name, libs) + if include_path: + self._extra_include_paths[dep_name] = include_path + if ldflags: + self._extra_ldflags[dep_name] = ldflags + + @staticmethod + def is_available() -> bool: + """Check if CUDA is available in the current environment.""" + try: + import torch + except ImportError: + return False + return torch.cuda.is_available() + + def can_build(self, sol: Solution) -> bool: + return sol.spec.language == SupportedLanguages.CUDA + + def _get_build_path(self, key: str) -> Path: + """Get the build directory path for a given cache key. + + Parameters + ---------- + key : str + Unique cache key for the solution + + Returns + ------- + Path + Directory path where build results will be stored + """ + return get_fib_cache_path() / self._BUILD_DIR_NAME / key + + def get_key(self, solution: Solution) -> str: + return create_package_name(solution, self._KEY_PREFIX) + + def _filter_sources(self, source_paths: List[Path]) -> List[str]: + return [str(path) for path in source_paths if path.suffix in _CPP_CUDA_EXTENSIONS] + + def _get_dependency_flags(self, sol: Solution) -> Tuple[List[str], List[str]]: + extra_include_paths = [] + extra_ldflags = [] + + for dep in sol.spec.dependencies: + if dep not in _CUDA_DEPS.keys(): + logger.warning(f"Dependency '{dep}' not found in CUDA_DEPS") + continue + inc_path = self._extra_include_paths.get(dep) + if not inc_path: + raise BuildError( + f"{dep} is not available in the current environment but referenced " + f"by {sol.name}" + ) + extra_include_paths.append(inc_path) + ldflags = self._extra_ldflags.get(dep) + if ldflags: + extra_ldflags.extend(ldflags) + + return extra_include_paths, extra_ldflags + + def _get_cleaner(self, build_dir: Path) -> Callable[[], None]: + """Get a cleaner function for the build directory.""" + + def cleaner() -> None: + shutil.rmtree(build_dir, ignore_errors=True) + + return cleaner + + def build(self, definition: Definition, solution: Solution) -> Runnable: + from torch.utils.cpp_extension import load + + entry_file_extension = solution.get_entry_path().suffix + if entry_file_extension not in _CPP_CUDA_EXTENSIONS: + raise BuildError( + f"Entry file type not recognized. Must be one of {_CPP_CUDA_EXTENSIONS}, " + f"got {entry_file_extension}." + ) + + symbol = solution.get_entry_symbol() + key = self.get_key(solution) + build_dir = self._get_build_path(key) + src_paths = write_sources_to_path(build_dir, solution.sources) + + src_paths = self._filter_sources(src_paths) + + extra_include_paths, extra_ldflags = self._get_dependency_flags(solution) + # Add build directory to include paths + extra_include_paths.append(build_dir) + + try: + ext = load( + name=key, + sources=src_paths, + extra_include_paths=extra_include_paths, + extra_ldflags=extra_ldflags, + with_cuda=True, + build_directory=build_dir, + verbose=True, + ) + except Exception as e: + raise BuildError(f"CUDA build failed for solution '{solution.name}': {e}") from e + + try: + callable = getattr(ext, symbol) + except AttributeError as e: + raise BuildError(f"Exported symbol '{symbol}' not found in built extension") from e + + metadata = RunnableMetadata( + build_type="torch", + definition=definition.name, + solution=solution.name, + misc={"entry_symbol": symbol, "binary": getattr(ext, "__file__", None)}, + ) + + cleaner = self._get_cleaner(build_dir) + + return Runnable(callable=callable, metadata=metadata, cleaner=cleaner) diff --git a/flashinfer_bench/compile/builders/triton_builder.py b/flashinfer_bench/compile/builders/triton_builder.py index 7ed81160..54ea3693 100644 --- a/flashinfer_bench/compile/builders/triton_builder.py +++ b/flashinfer_bench/compile/builders/triton_builder.py @@ -1,51 +1,32 @@ from __future__ import annotations -from flashinfer_bench.compile.builder import Builder, BuildError, create_pkg_name +from typing import ClassVar + from flashinfer_bench.compile.runnable import Runnable +from flashinfer_bench.compile.utils import create_package_name from flashinfer_bench.data import Definition, Solution, SupportedLanguages from .python_builder import PythonBuilder -def _verify_triton() -> bool: - try: - import triton - except Exception: - return False - return True - - -class TritonBuilder(Builder): - _triton_available: bool = None - - @classmethod - def _get_triton_available(cls) -> bool: - if cls._triton_available is None: - cls._triton_available = _verify_triton() - return cls._triton_available +class TritonBuilder(PythonBuilder): + _KEY_PREFIX: ClassVar[str] = "fib_triton_" - def __init__(self, py_builder: PythonBuilder) -> None: - super().__init__() - self._py_builder = py_builder + @staticmethod + def is_available() -> bool: + try: + import triton + except ImportError: + return False + return True def can_build(self, sol: Solution) -> bool: - return sol.spec.language == SupportedLanguages.TRITON and self._get_triton_available() - - def _make_key(self, solution: Solution) -> str: - return f"triton::{create_pkg_name(solution)}" - - def _make_closer(self, *args, **kwargs): - raise NotImplementedError("Triton uses PythonBuilder's closer through _build") - - def _build(self, defn: Definition, sol: Solution) -> Runnable: - if not self._get_triton_available(): - raise BuildError("Triton is not available in the current environment") + return sol.spec.language == SupportedLanguages.TRITON - import triton + def get_key(self, solution: Solution) -> str: + return create_package_name(solution, self._KEY_PREFIX) - # Reuse Python builder for source layout and import - runnable = self._py_builder._build(defn, sol) - runnable.meta.update( - {"language": "triton", "triton_version": getattr(triton, "__version__", None)} - ) - return runnable + def build(self, definition: Definition, solution: Solution) -> Runnable: + result = super().build(definition, solution) + result.metadata.build_type = "triton" + return result diff --git a/flashinfer_bench/compile/builders/tvm_ffi_builder.py b/flashinfer_bench/compile/builders/tvm_ffi_builder.py index 3e0de8ff..4df6a025 100644 --- a/flashinfer_bench/compile/builders/tvm_ffi_builder.py +++ b/flashinfer_bench/compile/builders/tvm_ffi_builder.py @@ -3,32 +3,25 @@ from __future__ import annotations import logging +import shutil from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Callable, Dict, List, Tuple import tvm_ffi from tvm_ffi.utils import FileLock -from flashinfer_bench.compile.builder import Builder, BuildError, create_pkg_name -from flashinfer_bench.compile.runnable import Runnable, TVMFFIRunnable +from flashinfer_bench.compile.builder import Builder, BuildError +from flashinfer_bench.compile.runnable import Runnable, RunnableMetadata, TVMFFIRunnable +from flashinfer_bench.compile.utils import create_package_name, write_sources_to_path from flashinfer_bench.data import Definition, Solution, SupportedLanguages from flashinfer_bench.env import get_fib_cache_path logger = logging.getLogger(__name__) # File extension mappings for source file classification -CUDA_EXTENSIONS = [".cu"] # CUDA source files -CPP_EXTENSIONS = [".cpp", ".cc", ".cxx", ".c"] # C/C++ source files - - -class Language(Enum): - """Enum representing source code languages supported by the builder.""" - - CUDA = "cuda" - """The solution's language is CUDA""" - CPP = "cpp" - """The solution's language is C/C++""" +_CUDA_EXTENSIONS: List[str] = [".cu"] # CUDA source files +_CPP_EXTENSIONS: List[str] = [".cpp", ".cc", ".cxx", ".c"] # C/C++ source files class TVMFFIBuilder(Builder): @@ -70,6 +63,15 @@ def __init__(self) -> None: self._extra_include_paths: Dict[str, str] = {} self._extra_ldflags: Dict[str, List[str]] = {} + @staticmethod + def is_available() -> bool: + """Check if TVM-FFI is available in the current environment.""" + try: + import tvm_ffi + except ImportError: + return False + return True + def can_build(self, sol: Solution) -> bool: """Check if this builder can build the given solution. @@ -85,7 +87,7 @@ def can_build(self, sol: Solution) -> bool: """ return sol.spec.language == SupportedLanguages.CUDA - def _make_key(self, solution: Solution) -> str: + def get_key(self, solution: Solution) -> str: """Generate unique cache key for a solution. Parameters @@ -98,17 +100,7 @@ def _make_key(self, solution: Solution) -> str: str Unique key combining builder name and solution package name """ - return self._KEY_PREFIX + create_pkg_name(solution) - - def _make_closer(self): - """Create a closer function for resource cleanup. - - Returns - ------- - callable - No-op closer since TVM-FFI handles cleanup automatically - """ - return lambda: None + return create_package_name(solution, self._KEY_PREFIX) def _get_build_path(self, key: str) -> Path: """Get the build directory path for a given cache key. @@ -179,51 +171,13 @@ def _check_sources(self, path: Path, key: str, sol: Solution) -> bool: # All checks passed: can use cached .so return True - def _detect_language(self, sol: Solution) -> Language: - """Detect source language based on file extensions. + def _filter_sources(self, source_paths: List[Path]) -> Tuple[List[str], List[str]]: + """Filter source files by extension into C++ and CUDA source file paths. Parameters ---------- - sol : Solution - Solution containing source files - - Returns - ------- - Language - CUDA if any .cu files present, otherwise CPP - - Raises - ------ - BuildError - If no valid source files found - """ - has_cuda = False - has_cpp = False - - for src in sol.sources: - path_str = str(src.path) - if path_str.endswith(tuple(CUDA_EXTENSIONS)): - has_cuda = True - elif path_str.endswith(tuple(CPP_EXTENSIONS)): - has_cpp = True - - if not has_cuda and not has_cpp: - raise BuildError("No CUDA or C++ sources found") - - return Language.CUDA if has_cuda else Language.CPP - - def _write_sources(self, path: Path, sol: Solution) -> Tuple[List[str], List[str]]: - """Write all source files to build directory and collect file paths. - - Creates parent directories as needed for files in subdirectories. - Overwrites files unconditionally (caller already determined a full build is needed). - - Parameters - ---------- - path : Path - Build directory where source files will be written - sol : Solution - Solution containing source files to write + source_paths : List[Path] + List of source file paths. Returns ------- @@ -232,30 +186,13 @@ def _write_sources(self, path: Path, sol: Solution) -> Tuple[List[str], List[str cuda_files : List[str] List of CUDA source file paths """ - path.mkdir(parents=True, exist_ok=True) cpp_files: List[str] = [] cuda_files: List[str] = [] - - for src in sol.sources: - # Defensive assertion: path should be validated at Solution creation time - src_path_obj = Path(src.path) - assert not src_path_obj.is_absolute(), f"Absolute path detected: {src.path}" - assert ".." not in src_path_obj.parts, f"Path traversal detected: {src.path}" - - src_path = path / src.path - - # Ensure parent directories exist - src_path.parent.mkdir(parents=True, exist_ok=True) - - # Write source file - src_path.write_text(src.content) - - # Collect file path by extension - path_str = str(src_path) - if path_str.endswith(tuple(CPP_EXTENSIONS)): - cpp_files.append(path_str) - elif path_str.endswith(tuple(CUDA_EXTENSIONS)): - cuda_files.append(path_str) + for src_path in source_paths: + if src_path.suffix in _CPP_EXTENSIONS: + cpp_files.append(str(src_path)) + elif src_path.suffix in _CUDA_EXTENSIONS: + cuda_files.append(str(src_path)) return cpp_files, cuda_files @@ -284,52 +221,26 @@ def _get_entry_symbol(self, sol: Solution) -> str: ) return entry_point.split("::")[-1] - def _make_runnable( - self, mod: tvm_ffi.Module, entry_symbol: str, defn: Definition, metadata: Dict[str, Any] - ) -> TVMFFIRunnable: - """Create Runnable from TVM-FFI module. - - Wraps the compiled function with a keyword argument adapter that matches - the definition's input/output interface. + def _get_cleaner(self, build_path: Path) -> Callable[[], None]: + """Get a cleaner function for the build directory. It will remove the build directory. Parameters ---------- - mod : tvm_ffi.Module - Loaded TVM-FFI module containing the compiled function - entry_symbol : str - Name of the function to extract from the module - defn : Definition - Definition specifying the function interface - metadata : Dict[str, Any] - Metadata about the build (language, paths, etc.) + build_path : Path + The path to the build directory Returns ------- - TVMFFIRunnable - Runnable wrapper that handles tensor allocation and keyword arguments - - Raises - ------ - BuildError - If the entry_symbol is not found in the module + callable + A function that cleans up the build directory. """ - try: - fn = getattr(mod, entry_symbol) - except AttributeError as e: - raise BuildError(f"Entry point '{entry_symbol}' not found in module") from e - - # Create keyword adapter to match definition interface - arg_order = list(defn.inputs.keys()) + list(defn.outputs.keys()) - def _kw_adapter(**kwargs): - args = [kwargs[name] for name in arg_order] - return fn(*args) + def cleaner() -> None: + shutil.rmtree(build_path, ignore_errors=True) - return TVMFFIRunnable( - fn=_kw_adapter, closer=self._make_closer(), meta=metadata, definition=defn - ) + return cleaner - def _build(self, defn: Definition, sol: Solution) -> Runnable: + def build(self, definition: Definition, solution: Solution) -> Runnable: """Build with automatic caching - compile once, load from cache afterwards. This method implements intelligent caching: @@ -342,9 +253,9 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable: Parameters ---------- - defn : Definition + definition : Definition Problem definition specifying inputs/outputs - sol : Solution + solution : Solution Solution containing source code and build specification Returns @@ -357,14 +268,13 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable: BuildError If compilation fails, module loading fails, or entry point is invalid """ - key = self._make_key(sol) + key = self.get_key(solution) build_path = self._get_build_path(key) - entry_symbol = self._get_entry_symbol(sol) - language = self._detect_language(sol) - can_use_cached = self._check_sources(build_path, key, sol) + entry_symbol = self._get_entry_symbol(solution) + can_use_cached = self._check_sources(build_path, key, solution) - # Check if cached .so can be used - # This checking and rebuilding is thread-safe through the FileLock + # Check if cached .so can be used. If not, build the solution. + # This check and build are thread-safe through the FileLock if can_use_cached: output_lib_path = str(build_path / f"{key}.so") else: @@ -372,10 +282,11 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable: build_path.mkdir(parents=True, exist_ok=True) with FileLock(build_path / self._LOCK_FILE_NAME): # Double-check after acquiring lock (another process may have built it) - if self._check_sources(build_path, key, sol): + if self._check_sources(build_path, key, solution): output_lib_path = str(build_path / f"{key}.so") else: - cpp_files, cuda_files = self._write_sources(build_path, sol) + src_paths = write_sources_to_path(build_path, solution.sources) + cpp_files, cuda_files = self._filter_sources(src_paths) extra_include_paths = [str(build_path)] try: # Compile sources to shared library @@ -387,7 +298,9 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable: build_directory=build_path, ) except Exception as e: - raise BuildError(f"TVM-FFI compilation failed for '{sol.name}': {e}") from e + raise BuildError( + f"TVM-FFI compilation failed for '{solution.name}': {e}" + ) from e # Load the compiled module try: @@ -396,14 +309,22 @@ def _build(self, defn: Definition, sol: Solution) -> Runnable: raise BuildError(f"Failed to load compiled module: {e}") from e # Create metadata for the runnable - metadata = { - "definition": defn.name, - "solution": sol.name, - "language": language.value, - "binding": "tvm_ffi", - "key": key, - "symbol": entry_symbol, - "binary": output_lib_path, - } - - return self._make_runnable(mod, entry_symbol, defn, metadata) + metadata = RunnableMetadata( + build_type="tvm_ffi", + definition=definition.name, + solution=solution.name, + misc={ + "definition": definition, + "key": key, + "symbol": entry_symbol, + "binary": output_lib_path, + }, + ) + + try: + fn = getattr(mod, entry_symbol) + except AttributeError as e: + raise BuildError(f"Entry point '{entry_symbol}' not found in module") from e + + cleaner = self._get_cleaner(build_path) + return Runnable(callable=fn, metadata=metadata, cleaner=cleaner) diff --git a/flashinfer_bench/compile/registry.py b/flashinfer_bench/compile/registry.py index 0e108e77..4e4468f9 100644 --- a/flashinfer_bench/compile/registry.py +++ b/flashinfer_bench/compile/registry.py @@ -1,33 +1,63 @@ from __future__ import annotations -from typing import Tuple +from typing import ClassVar, Dict, List, Type from flashinfer_bench.data import BuildSpec, Definition, Solution, SourceFile, SupportedLanguages from .builder import Builder, BuildError +from .builders import PythonBuilder, TorchBuilder, TritonBuilder, TVMFFIBuilder from .runnable import Runnable +_BUILDER_PRIORITY: List[Type[Builder]] = [TritonBuilder, PythonBuilder, TVMFFIBuilder, TorchBuilder] +"""Contains all builders in the order of priority.""" + class BuilderRegistry: """Registry that dispatches to the first capable builder.""" - def __init__(self, builders: Tuple[Builder, ...]) -> None: - if not builders: + _instance: ClassVar["BuilderRegistry" | None] = None + """Singleton instance of the BuilderRegistry.""" + + _builders: List[Builder] + """List of builders in the order of priority.""" + _cache: Dict[str, Runnable] + """Cache of built runnables.""" + + def __init__(self, builders: List[Builder]) -> None: + if len(builders) == 0: raise ValueError("BuilderRegistry requires at least one builder") - self._builders: Tuple[Builder, ...] = builders + self._builders = list(builders) + self._cache: Dict[str, Runnable] = {} def clear(self) -> None: - for b in self._builders: + for runnable in self._cache.values(): try: - b.clear_cache() + runnable.cleanup() except Exception: pass + self._cache.clear() + + @classmethod + def get_instance(cls) -> "BuilderRegistry": + if cls._instance is None: + builders = [] + for builder_type in _BUILDER_PRIORITY: + if builder_type.is_available(): + builders.append(builder_type()) + cls._instance = BuilderRegistry(builders) + return cls._instance def build(self, defn: Definition, sol: Solution) -> Runnable: + hash = sol.hash() + if hash in self._cache: + return self._cache[hash] + for builder in self._builders: - # Choose the first + # Choose the first builder that can build the solution if builder.can_build(sol): - return builder.build(defn, sol) + runnable = builder.build(defn, sol) + self._cache[hash] = runnable + return runnable raise BuildError(f"No registered builder can build solution '{sol.name}'") def build_reference(self, defn: Definition) -> Runnable: @@ -44,21 +74,3 @@ def build_reference(self, defn: Definition) -> Runnable: description="reference", ) return self.build(defn, pseudo) - - -_registry: BuilderRegistry | None = None - - -def get_builder_registry() -> BuilderRegistry: - global _registry - if _registry is None: - from .builders import CUDABuilder, PythonBuilder, TritonBuilder, TVMFFIBuilder - - py = PythonBuilder() - triton = TritonBuilder(py_builder=py) - tvm_ffi = TVMFFIBuilder() - cuda = CUDABuilder() # Fallback for backward compatibility - - # Priority: Python > Triton > TVM-FFI > CUDA (pybind11) - _registry = BuilderRegistry((py, triton, tvm_ffi, cuda)) - return _registry diff --git a/flashinfer_bench/compile/runnable.py b/flashinfer_bench/compile/runnable.py index fe309df9..9016341c 100644 --- a/flashinfer_bench/compile/runnable.py +++ b/flashinfer_bench/compile/runnable.py @@ -1,63 +1,124 @@ from __future__ import annotations -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Literal, Optional + +from pydantic import BaseModel from flashinfer_bench.data import Definition from flashinfer_bench.utils import dtype_str_to_torch_dtype +BuildType = Literal["cuda", "tvm_ffi", "python", "triton"] +"""The type of build that produced this runnable. Each builder has a unique build type.""" + + +class RunnableMetadata(BaseModel): + """Metadata about the runnable.""" + + build_type: BuildType + """The type of build that produced this runnable.""" + definition: str + """The definition that was used to build this runnable.""" + solution: str + """The solution that was used to build this runnable.""" + misc: Dict[str, Any] + """Miscellaneous metadata about the runnable.""" + class Runnable: + """A callable that is compiled from a solution. The runnable contains a callable, metadata, + and a closer function.""" + + metadata: RunnableMetadata + """The metadata for the runnable.""" + + _callable: Callable[..., Any] + """The callable that is wrapped by the runnable.""" + _closer: Optional[Callable[[], None]] + """The closer function for the runnable.""" + def __init__( - self, fn: Callable[..., Any], closer: Optional[Callable[[], None]], meta: Dict[str, Any] + self, + callable: Callable[..., Any], + metadata: RunnableMetadata, + cleaner: Optional[Callable[[], None]] = None, ) -> None: - """A runnable callable with a required resource closer. - - closer: must be provided by the builder and be idempotent. + """Constructor for the Runnable class. + + Parameters + ---------- + callable : Callable[..., Any] + The callable that is wrapped by the runnable. + metadata : RunnableMetadata + The metadata for the runnable. + cleaner : Optional[Callable[[], None]] + The cleaner function for the runnable. It will clean up the build artifacts/resources. """ - self._fn = fn - self._closer: Optional[Callable[[], None]] = closer - self.meta: Dict[str, Any] = meta + self._callable = callable + self.metadata = metadata + self._cleaner = cleaner def __call__(self, **kwargs: Any) -> Any: """ - - Accept kwargs only (aligns with Definition.inputs naming) - - Unpack a single-element tuple to a scalar value - - No type/shape/count validation; errors surface naturally + Call the underlying function, and return the result. If the result is a single-element + tuple, unpack it. + + Parameters + ---------- + args : Any + The positional arguments to pass to the underlying function. + kwargs : Any + The keyword arguments to pass to the underlying function. + + Returns + ------- + Any + The result of the underlying function. If the result is a single-element tuple, + unpack it to a scalar value. """ - ret = self._fn(**kwargs) + ret = self._callable(**kwargs) if isinstance(ret, tuple) and len(ret) == 1: return ret[0] return ret - def close(self) -> None: - """Release build artifacts/resources; must be idempotent.""" - if self._closer: - try: - self._closer() - finally: - self._closer = None + def call_dps(self, **kwargs: Any) -> Any: + """Call a destination-passing style (DPS) function in value-returning style. + This method assumes the callable is destination-passing style:: -class TVMFFIRunnable(Runnable): - def __init__( - self, - fn: Callable[..., Any], - closer: Optional[Callable[[], None]], - meta: Dict[str, Any], - definition: Definition, - ) -> None: - super().__init__(fn, closer, meta) - self._definition = definition + function(**kwargs, **output_tensors) -> None - def __call__(self, **kwargs: Any) -> Any: + And calling this method will call the DPS function in value-returning style: + + runnable.call_dps(**kwargs) -> output_tensors + + It will internally allocate output tensors, call the callable with the provided inputs + and allocated output tensors, and return the results. + + Parameters + ---------- + kwargs : Any + The keyword arguments to pass to the underlying function. + + Returns + ------- + Any + """ import torch - # Allocate output tensors first + if "definition" not in self.metadata.misc or not isinstance( + self.metadata.misc["definition"], Definition + ): + raise ValueError( + "When calling in destination passing style, metadata.misc must " + "contain the full definition." + ) + definition: Definition = self.metadata.misc["definition"] - var_values = self._definition.get_var_values( + # Allocate output tensors first + var_values = definition.get_var_values( {name: list(tensor.shape) for name, tensor in kwargs.items()} ) - output_shapes = self._definition.get_output_shapes(var_values) + output_shapes = definition.get_output_shapes(var_values) output_tensors: Dict[str, torch.Tensor] = {} # Determine device from input tensors @@ -68,16 +129,22 @@ def __call__(self, **kwargs: Any) -> Any: for name, shape in output_shapes.items(): output_tensors[name] = torch.empty( - shape, dtype=dtype_str_to_torch_dtype(self._definition.outputs[name].dtype) + shape, dtype=dtype_str_to_torch_dtype(definition.outputs[name].dtype) ).to(device) - self.call_dest(**kwargs, **output_tensors) + self._callable(**kwargs, **output_tensors) - results = list(output_tensors.values()) + results = tuple(output_tensors.values()) + if len(results) == 0: + return None if len(results) == 1: return results[0] return results - def call_dest(self, **kwargs: Any) -> None: - """Call the underlying function with destination passing style.""" - self._fn(**kwargs) + def cleanup(self) -> None: + """Clean up the build artifacts/resources.""" + if self._closer: + try: + self._closer() + finally: + self._closer = None diff --git a/flashinfer_bench/compile/utils.py b/flashinfer_bench/compile/utils.py new file mode 100644 index 00000000..ff595d5f --- /dev/null +++ b/flashinfer_bench/compile/utils.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import hashlib +import re +from pathlib import Path +from typing import List + +from flashinfer_bench.data import Solution, SourceFile + + +def write_sources_to_path(path: Path, sources: List[SourceFile]) -> List[Path]: + """Write a list of source files to the given directory. + + Creates parent directories as needed for files in subdirectories. + Overwrites files unconditionally (caller already determined a full build is needed). + Each source file should not contain parent directory traversal ("..") or absolute paths, and + should be unique. + + Parameters + ---------- + path : Path + The directory path to write the source files to. + sources : list[SourceFile] + The list of source files to write. + """ + path.mkdir(parents=True, exist_ok=True) + paths: List[Path] = [] + for src in sources: + # Defensive assertion: path should be validated at Solution creation time + src_path_obj = Path(src.path) + + assert not src_path_obj.is_absolute(), f"Absolute path detected: {src.path}" + assert ".." not in src_path_obj.parts, f"Path traversal detected: {src.path}" + + src_path = path / src.path + + # Ensure parent directories exist + src_path.parent.mkdir(parents=True, exist_ok=True) + + # Write source file + src_path.write_text(src.content) + paths.append(src_path) + + return paths + + +def create_package_name(sol: Solution, prefix: str = "") -> str: + """Create a package name for a solution. The name is created by normalizing the solution name + and hashing the sources. + + Parameters + ---------- + sol : Solution + The solution to create a package name for. + prefix : str + The prefix to add to the package name. + + Returns + ------- + str + The package name for the solution. + """ + # Normalize the solution name + s = re.sub(r"[^0-9a-zA-Z_]", "_", sol.name) + if not s or s[0].isdigit(): + s = "_" + s + + return prefix + s + "_" + sol.hash()[:6] diff --git a/flashinfer_bench/data/solution.py b/flashinfer_bench/data/solution.py index 0c3d3fcb..eb3aa0d8 100644 --- a/flashinfer_bench/data/solution.py +++ b/flashinfer_bench/data/solution.py @@ -1,5 +1,6 @@ """Strong-typed data definitions for solution implementations.""" +import hashlib from enum import Enum from pathlib import Path from typing import List, Optional @@ -74,8 +75,8 @@ class BuildSpec(BaseModelWithDocstrings): """The exact path to the function to be called. Format: '{file_path}::{function_name}' (e.g., 'main.py::run').""" dependencies: Optional[List[NonEmptyString]] = Field(default=[]) - """Optional list of required libraries or toolchains (e.g., 'CUDA >= 12.0', - 'triton >= 2.2').""" + """Optional list of required libraries or packages. E.g. for CUDA, we support 'cublas', + 'cudnn', 'cutlass'""" @model_validator(mode="after") def _validate_entry_point(self) -> "BuildSpec": @@ -139,6 +140,26 @@ def _validate_source_path_entry_point(self) -> "Solution": return self + def get_entry_path(self) -> Path: + """Get the path to the entry source file. + + Returns + ------- + str + The path to the entry source file. + """ + return Path(self.spec.entry_point.split("::")[0]) + + def get_entry_symbol(self) -> str: + """Extract function symbol from entry_point. + + Returns + ------- + str + The function symbol name to be loaded from the compiled module + """ + return self.spec.entry_point.split("::")[-1] + def get_entry_source(self) -> Optional[SourceFile]: """Get the entry source file specified in the build spec. @@ -147,19 +168,29 @@ def get_entry_source(self) -> Optional[SourceFile]: Optional[SourceFile] The SourceFile object containing the entry point, or None if not found. """ - entry_path = self.spec.entry_point.split("::")[0] + entry_path = self.get_entry_path() for source in self.sources: if source.path == entry_path: return source return None - def requires_build(self) -> bool: - """Check if the solution requires a build step. + def hash(self) -> str: + """Hash the solution. It returns the SHA1 hash of the solution. Returns ------- - bool - True if the solution requires building (has build commands or uses CUDA), - False otherwise. + str + The hash of the solution. """ - return self.spec.language == SupportedLanguages.CUDA + h = hashlib.sha1() + for s in ( + self.name, + self.definition, + self.spec.language, + self.spec.entry_point, + *self.spec.dependencies, + *(part for src in self.sources for part in (src.path, src.content)), + ): + h.update(s.encode()) + + return h.hexdigest() diff --git a/flashinfer_bench/data/trace_set.py b/flashinfer_bench/data/trace_set.py index 0da875c5..85933652 100644 --- a/flashinfer_bench/data/trace_set.py +++ b/flashinfer_bench/data/trace_set.py @@ -18,7 +18,6 @@ from .trace import EvaluationStatus, Trace -# TODO(shanli): TraceSet wide validation @dataclass class TraceSet: """Stores a FlashInfer Trace dataset containing definitions, solutions, workloads, and traces. diff --git a/tests/apply/test_runtime.py b/tests/apply/test_runtime.py index ff57bb57..b539f5b3 100644 --- a/tests/apply/test_runtime.py +++ b/tests/apply/test_runtime.py @@ -243,7 +243,7 @@ def ev(sp: float) -> Evaluation: from flashinfer_bench.compile.builders.python_builder import PythonBuilder counts = {"build": 0} - orig_build = PythonBuilder._build + orig_build = PythonBuilder.build def counting_build(self, definition: Definition, solution: Solution) -> Runnable: counts["build"] += 1 @@ -251,7 +251,7 @@ def counting_build(self, definition: Definition, solution: Solution) -> Runnable try: # Patch at class so the registry instance picks it up - PythonBuilder._build = counting_build # type: ignore[assignment] + PythonBuilder.build = counting_build # type: ignore[assignment] # Two dispatches for same key should reuse cached runnable class T: @@ -269,7 +269,7 @@ def __init__(self, shape: Tuple[int, ...]): # Only one real build should have occurred assert counts["build"] == 1 finally: - PythonBuilder._build = orig_build # type: ignore[assignment] + PythonBuilder.build = orig_build # type: ignore[assignment] set_apply_runtime(None) diff --git a/tests/compile/test_builder.py b/tests/compile/test_builder.py index 2e32b246..1b868fc5 100644 --- a/tests/compile/test_builder.py +++ b/tests/compile/test_builder.py @@ -18,14 +18,16 @@ class DummyBuilder(Builder): def can_build(self, solution: Solution) -> bool: return True - def _make_key(self, solution: Solution) -> str: + def get_key(self, solution: Solution) -> str: return f"dummy::{solution.name}" def _make_closer(self): return lambda: None def _build(self, definition: Definition, solution: Solution) -> Runnable: - return Runnable(fn=lambda **kw: kw, closer=self._make_closer(), meta={"dummy": True}) + return Runnable( + callable=lambda **kw: kw, cleaner=self._make_closer(), metadata={"dummy": True} + ) def test_builder_cache_and_key(): @@ -43,9 +45,9 @@ def test_builder_cache_and_key(): ) srcs = [SourceFile(path="main.py", content="def run(A):\n return A\n")] s = Solution(name="s1", definition="test_def", author="me", spec=spec, sources=srcs) - r1 = b.build(d, s) - r2 = b.build(d, s) - assert r1 is r2 # cache hit via _make_key + r1 = b.build_with_cache(d, s) + r2 = b.build_with_cache(d, s) + assert r1 is r2 # cache hit via get_key b.clear_cache() diff --git a/tests/compile/test_python_builder.py b/tests/compile/test_python_builder.py index 99bb03d8..cc476c81 100644 --- a/tests/compile/test_python_builder.py +++ b/tests/compile/test_python_builder.py @@ -38,15 +38,12 @@ def test_python_builder_minimum(tmp_path, monkeypatch): s = Solution(name="py_sol", definition="mm", author="me", spec=spec, sources=srcs) b = PythonBuilder() - r = b.build(d, s) + r = b.build_with_cache(d, s) # Call runnable with torch tensors A = [[1, 2], [3, 4]] B = [[0, 0], [0, 0]] out = r(A=A, B=B) assert out == A - # Ensure temp_dir recorded under our cache - assert r.meta.get("temp_dir") - assert str(cache_dir) in r.meta["temp_dir"] # Cleanup b.clear_cache() @@ -80,7 +77,7 @@ def test_python_builder_add(tmp_path, monkeypatch): # Build and run with torch tensors b = PythonBuilder() - r = b.build(defn, sol) + r = b.build_with_cache(defn, sol) X = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) Y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32) out = r(X=X, Y=Y) diff --git a/tests/compile/test_runnable.py b/tests/compile/test_runnable.py index ebed072e..92b82ef3 100644 --- a/tests/compile/test_runnable.py +++ b/tests/compile/test_runnable.py @@ -2,7 +2,7 @@ import pytest -from flashinfer_bench.compile.runnable import Runnable +from flashinfer_bench.compile.runnable import Runnable, RunnableMetadata def test_runnable_single_tuple_unpack_and_close_idempotent(): @@ -14,12 +14,16 @@ def fn(**kw): def closer(): calls["closed"] += 1 - r = Runnable(fn=fn, closer=closer, meta={"k": 1}) + metadata = RunnableMetadata( + build_type="python", definition="test", solution="test", misc={"k": 1} + ) + + r = Runnable(callable=fn, cleaner=closer, metadata=metadata) assert r() == 42 # Close twice should not error and closer should be called once - r.close() - r.close() - r.close() + r.cleanup() + r.cleanup() + r.cleanup() assert calls["closed"] == 1 diff --git a/tests/compile/test_triton_builder.py b/tests/compile/test_triton_builder.py index 728feebb..b0a241f9 100644 --- a/tests/compile/test_triton_builder.py +++ b/tests/compile/test_triton_builder.py @@ -54,7 +54,7 @@ def mock_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", mock_import) with pytest.raises(BuildError, match="Triton is not available"): - b.build(d, s) + b.build_with_cache(d, s) @pytest.mark.skipif(importlib.util.find_spec("triton") is None, reason="Triton not available") @@ -74,7 +74,7 @@ def test_triton_builder_minimum(tmp_path, monkeypatch): ) srcs = [SourceFile(path="m/main.py", content="import torch\n\ndef run(A):\n return A")] s = Solution(name="tri_ok", definition="d", author="a", spec=spec, sources=srcs) - r = b.build(d, s) + r = b.build_with_cache(d, s) out = r(A=[1, 2, 3]) assert out == [1, 2, 3] @@ -138,7 +138,7 @@ def run(X, Y): ) b = TritonBuilder(PythonBuilder()) - r = b.build(defn, sol) + r = b.build_with_cache(defn, sol) X = torch.arange(256, dtype=torch.float32, device="cuda") Y = 2 * torch.ones(256, dtype=torch.float32, device="cuda") Z = r(X=X, Y=Y) diff --git a/tests/compile/test_tvm_ffi_builder.py b/tests/compile/test_tvm_ffi_builder.py index 3047a70b..d5827794 100644 --- a/tests/compile/test_tvm_ffi_builder.py +++ b/tests/compile/test_tvm_ffi_builder.py @@ -120,7 +120,7 @@ def test_build_cpp_cpu() -> None: # Build and run builder = TVMFFIBuilder() - runnable = builder.build(definition, solution) + runnable = builder.build_with_cache(definition, solution) # Test execution with torch tensors - runnable returns output input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu", dtype=torch.float32) @@ -167,7 +167,7 @@ def test_build_cuda_gpu() -> None: # Build and run builder = TVMFFIBuilder() - runnable = builder.build(definition, solution) + runnable = builder.build_with_cache(definition, solution) # Test execution with torch tensors - runnable returns output n = 1024 @@ -251,13 +251,13 @@ def test_caching_builder_level() -> None: # First build builder = TVMFFIBuilder() time_start = time.monotonic() - runnable1 = builder.build(definition, solution) + runnable1 = builder.build_with_cache(definition, solution) time_end = time.monotonic() print(f"Time taken to build: {(time_end - time_start) * 1000} ms") # Second build should load from cache time_start = time.monotonic() - runnable2 = builder.build(definition, solution) + runnable2 = builder.build_with_cache(definition, solution) time_end = time.monotonic() print(f"Time taken to load from cache: {(time_end - time_start) * 1000} ms") @@ -300,14 +300,14 @@ def test_caching_cross_builder() -> None: # First build builder1 = TVMFFIBuilder() time_start = time.monotonic() - runnable1 = builder1.build(definition, solution) + runnable1 = builder1.build_with_cache(definition, solution) time_end = time.monotonic() print(f"Time taken to build: {(time_end - time_start) * 1000} ms") # Second build should load from cache builder2 = TVMFFIBuilder() time_start = time.monotonic() - runnable2 = builder2.build(definition, solution) + runnable2 = builder2.build_with_cache(definition, solution) time_end = time.monotonic() print(f"Time taken to load from cache: {(time_end - time_start) * 1000} ms") @@ -350,7 +350,7 @@ def test_call_dest_cpu() -> None: # Build builder = TVMFFIBuilder() - runnable = builder.build(definition, solution) + runnable = builder.build_with_cache(definition, solution) # Manually allocate input and output tensors input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], device="cpu", dtype=torch.float32) @@ -401,7 +401,7 @@ def test_invalid_entry_point() -> None: builder = TVMFFIBuilder() with pytest.raises(BuildError): - builder.build(definition, invalid_solution) + builder.build_with_cache(definition, invalid_solution) def test_no_sources() -> None: @@ -433,7 +433,7 @@ def test_no_sources() -> None: builder = TVMFFIBuilder() with pytest.raises(BuildError, match="No CUDA or C\\+\\+ sources"): - builder.build(definition, no_sources_solution) + builder.build_with_cache(definition, no_sources_solution) def test_source_in_subdirectory() -> None: @@ -466,7 +466,7 @@ def test_source_in_subdirectory() -> None: # Build and run builder = TVMFFIBuilder() - runnable = builder.build(definition, solution) + runnable = builder.build_with_cache(definition, solution) # Test execution input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], device="cpu", dtype=torch.float32) diff --git a/tests/data/test_solution.py b/tests/data/test_solution.py index 8cf41528..1e59898d 100644 --- a/tests/data/test_solution.py +++ b/tests/data/test_solution.py @@ -59,14 +59,12 @@ def test_solution_validation_and_helpers(): s2 = SourceFile(path="util.py", content="x=1\n") sol = Solution(name="sol1", definition="def1", author="me", spec=spec, sources=[s1, s2]) assert sol.get_entry_source() is s1 - assert sol.requires_build() is False # CUDA requires build cuda_spec = BuildSpec( language=SupportedLanguages.CUDA, target_hardware=["cuda"], entry_point="main.py::run" ) sol2 = Solution(name="sol2", definition="def1", author="you", spec=cuda_spec, sources=[s1]) - assert sol2.requires_build() is True # Duplicate source paths with pytest.raises(ValueError): From d9c9328f8aa870f10d3d1e188398b40c12e3355e Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 30 Nov 2025 18:10:18 -0500 Subject: [PATCH 2/2] update Signed-off-by: Ubospica --- flashinfer_bench/compile/__init__.py | 12 +- flashinfer_bench/compile/builder.py | 82 +++++-- flashinfer_bench/compile/builders/__init__.py | 2 + .../compile/builders/python_builder.py | 71 +++++- .../compile/builders/torch_builder.py | 207 +++++++++++++----- .../compile/builders/triton_builder.py | 46 +++- .../compile/builders/tvm_ffi_builder.py | 84 +++---- flashinfer_bench/compile/registry.py | 112 +++++++++- flashinfer_bench/compile/runnable.py | 78 ++++--- flashinfer_bench/compile/utils.py | 51 +++-- flashinfer_bench/data/solution.py | 31 ++- 11 files changed, 590 insertions(+), 186 deletions(-) diff --git a/flashinfer_bench/compile/__init__.py b/flashinfer_bench/compile/__init__.py index 1e0a4801..f4520ae0 100644 --- a/flashinfer_bench/compile/__init__.py +++ b/flashinfer_bench/compile/__init__.py @@ -1,6 +1,16 @@ """Compiler subsystem package. -Exports common builder types for convenience. +This package provides the infrastructure for building solutions into executable runnables. +It includes: +- Builder: Abstract base class for different language/build system implementations +- BuilderRegistry: Central registry for managing and dispatching builders +- Runnable: Executable wrapper around compiled solutions +- RunnableMetadata: Metadata about build process and source + +The typical workflow is: +1. Get the singleton registry: registry = BuilderRegistry.get_instance() +2. Build a solution: runnable = registry.build(definition, solution) +3. Execute: result = runnable(**inputs) """ from .builder import Builder, BuildError diff --git a/flashinfer_bench/compile/builder.py b/flashinfer_bench/compile/builder.py index 37cc967c..236bb548 100644 --- a/flashinfer_bench/compile/builder.py +++ b/flashinfer_bench/compile/builder.py @@ -1,11 +1,16 @@ +"""Abstract base class for solution builders.""" + from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict +from pathlib import Path +from typing import Dict, Tuple from flashinfer_bench.data import Definition, Solution +from flashinfer_bench.env import get_fib_cache_path from .runnable import Runnable +from .utils import create_package_name class BuildError(RuntimeError): @@ -13,25 +18,76 @@ class BuildError(RuntimeError): class Builder(ABC): - """Builder abstraction: (Definition, Solution) -> Runnable with hidden cache.""" + """Abstract base class for building solutions into runnable implementations. - @abstractmethod - def __init__(self) -> None: - """Constructor for the Builder class.""" - ... + A Builder transforms a (Definition, Solution) pair into a Runnable object, which is an + executable implementation of the solution. Different builders handle different programming + languages (e.g., Python, CUDA, Triton) and build systems. + + Subclasses must implement all its abstract methods. Expectedly, the concrete builder should + operate in the folder `FIB_CACHE_PATH / builder_specific_subfolder / key`, where `key` is + a unique identifier for the solution. + """ + + def __init__(self, key_prefix: str, build_dir_name: str) -> None: + """Initialize the builder.""" + self._key_prefix = key_prefix + self._build_dir_name = build_dir_name @abstractmethod def can_build(self, solution: Solution) -> bool: - """Build guard to check if this builder can handle the given solution.""" - ... + """Check if this builder can handle the given solution. - @abstractmethod - def get_key(self, solution: Solution) -> str: - """Get a unique key for a solution. This key is used to cache the build results and - used as the module name.""" + Parameters + ---------- + solution : Solution + The solution to check. + + Returns + ------- + bool + True if this builder can build the solution, False otherwise. + """ ... @abstractmethod def build(self, definition: Definition, solution: Solution) -> Runnable: - """Perform a real build and return a Runnable; raise BuildError on failure.""" + """Build a solution into a runnable implementation. + + This method compiles/loads the solution's source code and returns a Runnable + object that can be executed with the interface specified by the definition. + + Parameters + ---------- + definition : Definition + The problem definition that specifies the expected interface. + solution : Solution + The solution implementation to build. + + Returns + ------- + Runnable + An executable wrapper around the built implementation. + + Raises + ------ + BuildError + If the build fails for any reason (compilation errors, missing dependencies, etc.). + """ ... + + def get_package_name_and_build_path(self, solution: Solution) -> Tuple[str, Path]: + """Get the package name and build path for the solution. + + Parameters + ---------- + solution : Solution + The solution to get the package name and build path for. + + Returns + ------- + Tuple[str, Path]: The package name and build path for the solution. + """ + package_name = create_package_name(solution, self._key_prefix) + build_path = get_fib_cache_path() / self._build_dir_name / package_name + return package_name, build_path diff --git a/flashinfer_bench/compile/builders/__init__.py b/flashinfer_bench/compile/builders/__init__.py index 5932dffb..76cddb42 100644 --- a/flashinfer_bench/compile/builders/__init__.py +++ b/flashinfer_bench/compile/builders/__init__.py @@ -1,3 +1,5 @@ +"""Concrete builder implementations for different languages and build systems.""" + from .python_builder import PythonBuilder from .torch_builder import TorchBuilder from .triton_builder import TritonBuilder diff --git a/flashinfer_bench/compile/builders/python_builder.py b/flashinfer_bench/compile/builders/python_builder.py index f845b52e..3bff9f6f 100644 --- a/flashinfer_bench/compile/builders/python_builder.py +++ b/flashinfer_bench/compile/builders/python_builder.py @@ -1,3 +1,5 @@ +"""Builder for pure Python solutions.""" + from __future__ import annotations import importlib @@ -15,7 +17,12 @@ class PythonBuilder(Builder): - """Load a Python entry point from provided sources into a temporary module.""" + """Builder for Python solutions. + + This builder loads Python source files into a temporary module and returns a callable + that can be executed. The sources are written to a cache directory and imported as a + Python package. + """ _BUILD_DIR_NAME: ClassVar[str] = "python" """Subdirectory under FIB_CACHE_PATH where build results are stored.""" @@ -24,16 +31,51 @@ class PythonBuilder(Builder): """Prefix for cache keys to avoid collisions with other builders. fib_ prefix is added to avoid name collision in python imports.""" + def __init__(self) -> None: + super().__init__(self._KEY_PREFIX, self._BUILD_DIR_NAME) + def can_build(self, sol: Solution) -> bool: + """Check if this builder can handle the given solution.""" return sol.spec.language == SupportedLanguages.PYTHON def get_key(self, solution: Solution) -> str: + """Generate a unique cache key for the solution.""" return create_package_name(solution, self._KEY_PREFIX) def _get_build_path(self, key: str) -> Path: + """Get the build directory path for a given cache key. + + Parameters + ---------- + key : str + The unique cache key for the solution. + + Returns + ------- + Path + The directory path where the build artifacts will be stored. + """ return get_fib_cache_path() / self._BUILD_DIR_NAME / key def _get_cleaner(self, package: str, build_path: Path) -> Callable[[], None]: + """Create a cleaner function that removes build artifacts. + + The cleaner unloads the imported module, removes it from sys.path, and + deletes the build directory. + + Parameters + ---------- + package : str + The package name to unload from sys.modules. + build_path : Path + The directory to delete. + + Returns + ------- + Callable[[], None] + A function that performs the cleanup. + """ + def cleaner() -> None: try: # Unload module and submodules @@ -53,15 +95,36 @@ def cleaner() -> None: return cleaner def build(self, definition: Definition, solution: Solution) -> Runnable: + """Build a Python solution into a runnable. + + This method writes the solution sources to a temporary directory, imports the + module, and extracts the entry point function. + + Parameters + ---------- + definition : Definition + The problem definition. + solution : Solution + The Python solution to build. + + Returns + ------- + Runnable + An executable wrapper around the Python function. + + Raises + ------ + BuildError + If the entry file is not a Python file, the module import fails, or the + entry symbol is not found or not callable. + """ entry_file = solution.get_entry_path() entry_symbol = solution.get_entry_symbol() if entry_file.suffix != ".py": raise BuildError(f"Entry file '{entry_file}' is not a Python file") - # fib_python_some_solution_ - package_name = self.get_key(solution) - # fib_python_some_solution_.entry_file + package_name, build_path = self.get_package_name_and_build_path(solution) module_name = package_name + "." + ".".join(Path(entry_file).with_suffix("").parts) build_path = self._get_build_path(package_name) diff --git a/flashinfer_bench/compile/builders/torch_builder.py b/flashinfer_bench/compile/builders/torch_builder.py index 4d8ce696..6a1e60d9 100644 --- a/flashinfer_bench/compile/builders/torch_builder.py +++ b/flashinfer_bench/compile/builders/torch_builder.py @@ -1,3 +1,5 @@ +"""Builder for CUDA solutions using PyTorch's C++/CUDA extension system.""" + from __future__ import annotations import logging @@ -26,50 +28,14 @@ } -def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None): - include_path = None - ldflags = [] - - try: - include_dir = resources.files(pkg_name) / "include" - if include_dir.exists(): - include_path = str(include_dir) - - if lib_names: - lib_dir = resources.files(pkg_name) / "lib" - if lib_dir.exists(): - lib_path = Path(lib_dir) - - if sys.platform.startswith("linux"): - ldflags = [f"-L{lib_path}", f"-Wl,-rpath,{lib_path}"] - - for lib_name in lib_names: - # Look for unversioned .so first - lib_file = lib_path / f"lib{lib_name}.so" - if lib_file.exists(): - ldflags.append(f"-l{lib_name}") - else: - # Find versioned .so files - versioned = sorted(lib_path.glob(f"lib{lib_name}.so.*")) - if versioned: - ldflags.append(f"-l:{versioned[-1].name}") - else: - ldflags.append(f"-l{lib_name}") # Fallback - - elif sys.platform == "win32": - ldflags = [f"/LIBPATH:{lib_path}"] + lib_names - - except Exception: - logger.warning( - "Failed to discover resources for CUDA package '%s'; continuing without it.", - pkg_name, - exc_info=True, - ) - - return include_path, ldflags +class TorchBuilder(Builder): + """Builder for CUDA solutions using PyTorch's C++/CUDA extension loader. + This builder compiles C++/CUDA source files into a Python extension module using + torch.utils.cpp_extension.load(). It supports common CUDA dependencies like cuBLAS, + cuDNN, and CUTLASS. + """ -class TorchBuilder(Builder): _BUILD_DIR_NAME: ClassVar[str] = "torch" """Subdirectory under FIB_CACHE_PATH where build results are stored""" @@ -82,22 +48,98 @@ class TorchBuilder(Builder): """Extra link flags for CUDA dependencies""" def __init__(self) -> None: - super().__init__() + """Initialize the TorchBuilder and discover available CUDA dependencies.""" + super().__init__(self._KEY_PREFIX, self._BUILD_DIR_NAME) self._discover_cuda_deps() def _discover_cuda_deps(self): + """Discover available CUDA dependencies and their include/library paths. + + This method populates _extra_include_paths and _extra_ldflags with paths + for dependencies like cuBLAS, cuDNN, and CUTLASS if they are installed. + """ self._extra_include_paths = {} self._extra_ldflags = {} for dep_name, (pkg_name, libs) in _CUDA_DEPS.items(): - include_path, ldflags = _get_package_paths(pkg_name, libs) + include_path, ldflags = self._get_package_paths(pkg_name, libs) if include_path: self._extra_include_paths[dep_name] = include_path if ldflags: self._extra_ldflags[dep_name] = ldflags + def _get_package_paths( + self, pkg_name: str, lib_names: Optional[List[str]] = None + ) -> Tuple[Optional[str], List[str]]: + """Discover include and library paths for a given package. + + This function searches for the include and lib directories in a Python package + and returns the paths and linker flags needed to use them. + + Parameters + ---------- + pkg_name : str + The Python package name to search for (e.g., 'nvidia.cublas'). + lib_names : Optional[List[str]] + List of library names to link against (e.g., ['cublas', 'cublasLt']). + If None, only include paths are returned. + + Returns + ------- + Tuple[Optional[str], List[str]] + A tuple of (include_path, ldflags) where include_path is the path to the + include directory (or None if not found) and ldflags is a list of linker + flags to use the libraries. + """ + include_path = None + ldflags = [] + + try: + include_dir = resources.files(pkg_name) / "include" + if include_dir.exists(): + include_path = str(include_dir) + + if lib_names: + lib_dir = resources.files(pkg_name) / "lib" + if lib_dir.exists(): + lib_path = Path(lib_dir) + + if sys.platform.startswith("linux"): + ldflags = [f"-L{lib_path}", f"-Wl,-rpath,{lib_path}"] + + for lib_name in lib_names: + # Look for unversioned .so first + lib_file = lib_path / f"lib{lib_name}.so" + if lib_file.exists(): + ldflags.append(f"-l{lib_name}") + else: + # Find versioned .so files + versioned = sorted(lib_path.glob(f"lib{lib_name}.so.*")) + if versioned: + ldflags.append(f"-l:{versioned[-1].name}") + else: + ldflags.append(f"-l{lib_name}") # Fallback + + elif sys.platform == "win32": + ldflags = [f"/LIBPATH:{lib_path}"] + lib_names + + except Exception: + logger.warning( + "Failed to discover resources for CUDA package '%s'; continuing without it.", + pkg_name, + exc_info=True, + ) + + return include_path, ldflags + @staticmethod def is_available() -> bool: - """Check if CUDA is available in the current environment.""" + """Check if CUDA is available in the current environment. + + Returns + ------- + bool + True if PyTorch is installed and CUDA is available, False otherwise. + """ try: import torch except ImportError: @@ -105,6 +147,7 @@ def is_available() -> bool: return torch.cuda.is_available() def can_build(self, sol: Solution) -> bool: + """Check if this builder can handle the given solution.""" return sol.spec.language == SupportedLanguages.CUDA def _get_build_path(self, key: str) -> Path: @@ -122,13 +165,40 @@ def _get_build_path(self, key: str) -> Path: """ return get_fib_cache_path() / self._BUILD_DIR_NAME / key - def get_key(self, solution: Solution) -> str: - return create_package_name(solution, self._KEY_PREFIX) - def _filter_sources(self, source_paths: List[Path]) -> List[str]: + """Filter source files to include only C/C++/CUDA files. + + Parameters + ---------- + source_paths : List[Path] + List of all source file paths. + + Returns + ------- + List[str] + List of source file paths with C/C++/CUDA extensions. + """ return [str(path) for path in source_paths if path.suffix in _CPP_CUDA_EXTENSIONS] def _get_dependency_flags(self, sol: Solution) -> Tuple[List[str], List[str]]: + """Extract include paths and linker flags for solution dependencies. + + Parameters + ---------- + sol : Solution + The solution whose dependencies to process. + + Returns + ------- + Tuple[List[str], List[str]] + A tuple of (include_paths, ldflags) containing the include directories + and linker flags needed for the solution's dependencies. + + Raises + ------ + BuildError + If a required dependency is not available in the environment. + """ extra_include_paths = [] extra_ldflags = [] @@ -150,7 +220,18 @@ def _get_dependency_flags(self, sol: Solution) -> Tuple[List[str], List[str]]: return extra_include_paths, extra_ldflags def _get_cleaner(self, build_dir: Path) -> Callable[[], None]: - """Get a cleaner function for the build directory.""" + """Create a cleaner function that removes the build directory. + + Parameters + ---------- + build_dir : Path + The directory to delete. + + Returns + ------- + Callable[[], None] + A function that performs the cleanup. + """ def cleaner() -> None: shutil.rmtree(build_dir, ignore_errors=True) @@ -158,6 +239,29 @@ def cleaner() -> None: return cleaner def build(self, definition: Definition, solution: Solution) -> Runnable: + """Build a CUDA solution into a runnable. + + This method writes the solution sources to a build directory, compiles them + using PyTorch's cpp_extension.load(), and returns a callable wrapper. + + Parameters + ---------- + definition : Definition + The problem definition. + solution : Solution + The CUDA solution to build. + + Returns + ------- + Runnable + An executable wrapper around the compiled extension. + + Raises + ------ + BuildError + If the entry file is not a C/C++/CUDA file, compilation fails, or the + entry symbol is not found in the compiled extension. + """ from torch.utils.cpp_extension import load entry_file_extension = solution.get_entry_path().suffix @@ -168,8 +272,7 @@ def build(self, definition: Definition, solution: Solution) -> Runnable: ) symbol = solution.get_entry_symbol() - key = self.get_key(solution) - build_dir = self._get_build_path(key) + package_name, build_dir = self.get_package_name_and_build_path(solution) src_paths = write_sources_to_path(build_dir, solution.sources) src_paths = self._filter_sources(src_paths) @@ -180,7 +283,7 @@ def build(self, definition: Definition, solution: Solution) -> Runnable: try: ext = load( - name=key, + name=package_name, sources=src_paths, extra_include_paths=extra_include_paths, extra_ldflags=extra_ldflags, diff --git a/flashinfer_bench/compile/builders/triton_builder.py b/flashinfer_bench/compile/builders/triton_builder.py index 54ea3693..2496581d 100644 --- a/flashinfer_bench/compile/builders/triton_builder.py +++ b/flashinfer_bench/compile/builders/triton_builder.py @@ -1,19 +1,42 @@ +"""Builder for Triton GPU kernels.""" + from __future__ import annotations from typing import ClassVar +from flashinfer_bench.compile.builder import Builder from flashinfer_bench.compile.runnable import Runnable -from flashinfer_bench.compile.utils import create_package_name from flashinfer_bench.data import Definition, Solution, SupportedLanguages from .python_builder import PythonBuilder class TritonBuilder(PythonBuilder): + """Builder for Triton solutions. + + This builder extends PythonBuilder to handle Triton GPU kernels. Triton code + is Python-based, so the build process is similar to PythonBuilder, with the + main difference being the language tag in metadata. + """ + _KEY_PREFIX: ClassVar[str] = "fib_triton_" + """Prefix for cache keys to distinguish Triton solutions from pure Python ones.""" + + _BUILD_DIR_NAME: ClassVar[str] = "triton" + """Subdirectory under FIB_CACHE_PATH where build results are stored""" + + def __init__(self) -> None: + Builder.__init__(self, self._KEY_PREFIX, self._BUILD_DIR_NAME) @staticmethod def is_available() -> bool: + """Check if Triton is available in the current environment. + + Returns + ------- + bool + True if Triton is installed, False otherwise. + """ try: import triton except ImportError: @@ -21,12 +44,27 @@ def is_available() -> bool: return True def can_build(self, sol: Solution) -> bool: + """Check if this builder can handle the given solution.""" return sol.spec.language == SupportedLanguages.TRITON - def get_key(self, solution: Solution) -> str: - return create_package_name(solution, self._KEY_PREFIX) - def build(self, definition: Definition, solution: Solution) -> Runnable: + """Build a Triton solution into a runnable. + + This method delegates to PythonBuilder.build() and updates the build_type + in metadata to 'triton'. + + Parameters + ---------- + definition : Definition + The problem definition. + solution : Solution + The Triton solution to build. + + Returns + ------- + Runnable + An executable wrapper around the Triton kernel. + """ result = super().build(definition, solution) result.metadata.build_type = "triton" return result diff --git a/flashinfer_bench/compile/builders/tvm_ffi_builder.py b/flashinfer_bench/compile/builders/tvm_ffi_builder.py index 4df6a025..787d964b 100644 --- a/flashinfer_bench/compile/builders/tvm_ffi_builder.py +++ b/flashinfer_bench/compile/builders/tvm_ffi_builder.py @@ -4,15 +4,11 @@ import logging import shutil -from enum import Enum from pathlib import Path -from typing import Callable, Dict, List, Tuple - -import tvm_ffi -from tvm_ffi.utils import FileLock +from typing import Callable, ClassVar, List, Tuple from flashinfer_bench.compile.builder import Builder, BuildError -from flashinfer_bench.compile.runnable import Runnable, RunnableMetadata, TVMFFIRunnable +from flashinfer_bench.compile.runnable import Runnable, RunnableMetadata from flashinfer_bench.compile.utils import create_package_name, write_sources_to_path from flashinfer_bench.data import Definition, Solution, SupportedLanguages from flashinfer_bench.env import get_fib_cache_path @@ -44,24 +40,18 @@ class TVMFFIBuilder(Builder): >>> runnable.call_dest(x=input_tensor, output=output_tensor) # Destination-passing style """ - _BUILD_DIR_NAME = "tvm_ffi" + _KEY_PREFIX: ClassVar[str] = "tvm_ffi_" + """Prefix for cache keys to avoid collisions with other builders""" + + _BUILD_DIR_NAME: ClassVar[str] = "tvm_ffi" """Subdirectory under FIB_CACHE_PATH where build artifacts are stored""" - _LOCK_FILE_NAME = "flashinfer_bench_tvm_ffi_lock" + _LOCK_FILE_NAME: ClassVar[str] = "flashinfer_bench_tvm_ffi_lock" """File lock name for multi-process synchronization during compilation""" - _KEY_PREFIX = "tvm_ffi_" - """Prefix for cache keys to avoid collisions with other builders""" - def __init__(self) -> None: - """Initialize the TVMFFIBuilder. - - Sets up empty dictionaries for future extensibility with extra - include paths and linker flags (currently unused). - """ - super().__init__() - self._extra_include_paths: Dict[str, str] = {} - self._extra_ldflags: Dict[str, List[str]] = {} + """Initialize the TVMFFIBuilder.""" + super().__init__(self._KEY_PREFIX, self._BUILD_DIR_NAME) @staticmethod def is_available() -> bool: @@ -85,37 +75,10 @@ def can_build(self, sol: Solution) -> bool: bool True if solution language is CUDA (includes both .cu and .cpp files) """ - return sol.spec.language == SupportedLanguages.CUDA - - def get_key(self, solution: Solution) -> str: - """Generate unique cache key for a solution. - - Parameters - ---------- - solution : Solution - Solution to generate key for - - Returns - ------- - str - Unique key combining builder name and solution package name - """ - return create_package_name(solution, self._KEY_PREFIX) - - def _get_build_path(self, key: str) -> Path: - """Get the build directory path for a given cache key. - - Parameters - ---------- - key : str - Unique cache key for the solution - - Returns - ------- - Path - Directory path where build artifacts will be stored - """ - return get_fib_cache_path() / self._BUILD_DIR_NAME / key + return ( + sol.spec.language == SupportedLanguages.CUDA + or sol.spec.language == SupportedLanguages.CPP + ) def _check_sources(self, path: Path, key: str, sol: Solution) -> bool: """Check if the source code is vaild, and if the cached .so can be used by comparing source @@ -261,29 +224,32 @@ def build(self, definition: Definition, solution: Solution) -> Runnable: Returns ------- Runnable - TVMFFIRunnable that can be called with input tensors + A runnable wrapper around the compiled TVM-FFI module that supports both + value-returning style (via __call__) and destination-passing style (via call_dps) Raises ------ BuildError If compilation fails, module loading fails, or entry point is invalid """ - key = self.get_key(solution) - build_path = self._get_build_path(key) + import tvm_ffi + from tvm_ffi.utils import FileLock + + package_name, build_path = self.get_package_name_and_build_path(solution) entry_symbol = self._get_entry_symbol(solution) - can_use_cached = self._check_sources(build_path, key, solution) + can_use_cached = self._check_sources(build_path, package_name, solution) # Check if cached .so can be used. If not, build the solution. # This check and build are thread-safe through the FileLock if can_use_cached: - output_lib_path = str(build_path / f"{key}.so") + output_lib_path = str(build_path / f"{package_name}.so") else: # Ensure build directory exists before creating file lock build_path.mkdir(parents=True, exist_ok=True) with FileLock(build_path / self._LOCK_FILE_NAME): # Double-check after acquiring lock (another process may have built it) - if self._check_sources(build_path, key, solution): - output_lib_path = str(build_path / f"{key}.so") + if self._check_sources(build_path, package_name, solution): + output_lib_path = str(build_path / f"{package_name}.so") else: src_paths = write_sources_to_path(build_path, solution.sources) cpp_files, cuda_files = self._filter_sources(src_paths) @@ -291,7 +257,7 @@ def build(self, definition: Definition, solution: Solution) -> Runnable: try: # Compile sources to shared library output_lib_path = tvm_ffi.cpp.build( - name=key, + name=package_name, cpp_files=cpp_files, cuda_files=cuda_files, extra_include_paths=extra_include_paths, @@ -315,7 +281,7 @@ def build(self, definition: Definition, solution: Solution) -> Runnable: solution=solution.name, misc={ "definition": definition, - "key": key, + "package_name": package_name, "symbol": entry_symbol, "binary": output_lib_path, }, diff --git a/flashinfer_bench/compile/registry.py b/flashinfer_bench/compile/registry.py index 4e4468f9..f5f84803 100644 --- a/flashinfer_bench/compile/registry.py +++ b/flashinfer_bench/compile/registry.py @@ -1,3 +1,5 @@ +"""Builder registry for dispatching and caching builds.""" + from __future__ import annotations from typing import ClassVar, Dict, List, Type @@ -9,27 +11,61 @@ from .runnable import Runnable _BUILDER_PRIORITY: List[Type[Builder]] = [TritonBuilder, PythonBuilder, TVMFFIBuilder, TorchBuilder] -"""Contains all builders in the order of priority.""" +"""Builder types in priority order for automatic selection. + +When building a solution, the registry tries builders in this order and uses the first +one that reports it can build the solution. The order reflects typical preferences: +1. TritonBuilder - for Triton GPU kernels +2. PythonBuilder - for pure Python implementations +3. TVMFFIBuilder - for CUDA with TVM-FFI bindings (preferred for CUDA) +4. TorchBuilder - for CUDA with PyTorch extensions (fallback for CUDA) +""" class BuilderRegistry: - """Registry that dispatches to the first capable builder.""" + """Central registry for managing and dispatching builders. + + The BuilderRegistry maintains a list of available builders and automatically selects + the appropriate one for each solution. It also provides caching to avoid redundant + builds of the same solution. + + This class follows the singleton pattern - use get_instance() to obtain the shared + registry instance. + """ _instance: ClassVar["BuilderRegistry" | None] = None """Singleton instance of the BuilderRegistry.""" _builders: List[Builder] - """List of builders in the order of priority.""" + """List of available builders in priority order.""" _cache: Dict[str, Runnable] - """Cache of built runnables.""" + """Cache mapping solution hashes to built runnables.""" def __init__(self, builders: List[Builder]) -> None: + """Initialize the registry with a list of builders. + + Parameters + ---------- + builders : List[Builder] + List of builder instances to use. Must contain at least one builder. + + Raises + ------ + ValueError + If the builders list is empty. + """ if len(builders) == 0: raise ValueError("BuilderRegistry requires at least one builder") self._builders = list(builders) self._cache: Dict[str, Runnable] = {} def clear(self) -> None: + """Clear the cache and cleanup all built runnables. + + This method calls cleanup() on all cached runnables to release resources, + then clears the cache. Cleanup errors are caught and ignored to ensure + all runnables are processed. + """ for runnable in self._cache.values(): try: runnable.cleanup() @@ -39,6 +75,17 @@ def clear(self) -> None: @classmethod def get_instance(cls) -> "BuilderRegistry": + """Get the singleton registry instance. + + On first call, this method initializes the registry by instantiating all + available builders (those whose is_available() returns True) in priority order. + Subsequent calls return the same instance. + + Returns + ------- + BuilderRegistry + The shared registry instance. + """ if cls._instance is None: builders = [] for builder_type in _BUILDER_PRIORITY: @@ -48,6 +95,30 @@ def get_instance(cls) -> "BuilderRegistry": return cls._instance def build(self, defn: Definition, sol: Solution) -> Runnable: + """Build a solution into a runnable, using cache if available. + + This method first checks if the solution has already been built (by comparing + its hash). If not, it tries each registered builder in priority order until + one reports it can build the solution. The resulting runnable is cached for + future use. + + Parameters + ---------- + defn : Definition + The problem definition specifying the expected interface. + sol : Solution + The solution to build. + + Returns + ------- + Runnable + An executable wrapper around the built solution. + + Raises + ------ + BuildError + If no registered builder can build this solution, or if the build fails. + """ hash = sol.hash() if hash in self._cache: return self._cache[hash] @@ -61,6 +132,26 @@ def build(self, defn: Definition, sol: Solution) -> Runnable: raise BuildError(f"No registered builder can build solution '{sol.name}'") def build_reference(self, defn: Definition) -> Runnable: + """Build the reference implementation for a definition. + + This is a convenience method that creates a pseudo-solution from the definition's + reference code and builds it using the standard build() method. + + Parameters + ---------- + defn : Definition + The definition containing the reference implementation. + + Returns + ------- + Runnable + An executable wrapper around the reference implementation. + + Raises + ------ + BuildError + If the reference implementation cannot be built. + """ pseudo = Solution( name=f"{defn.name}__reference", definition=defn.name, @@ -74,3 +165,16 @@ def build_reference(self, defn: Definition) -> Runnable: description="reference", ) return self.build(defn, pseudo) + + +def get_builder_registry() -> BuilderRegistry: + """Get the singleton builder registry instance. + + This is a convenience function that delegates to BuilderRegistry.get_instance(). + + Returns + ------- + BuilderRegistry + The shared registry instance. + """ + return BuilderRegistry.get_instance() diff --git a/flashinfer_bench/compile/runnable.py b/flashinfer_bench/compile/runnable.py index 9016341c..1fabf1b8 100644 --- a/flashinfer_bench/compile/runnable.py +++ b/flashinfer_bench/compile/runnable.py @@ -1,3 +1,5 @@ +"""Runnable wrapper for compiled solutions.""" + from __future__ import annotations from typing import Any, Callable, Dict, Literal, Optional @@ -12,29 +14,37 @@ class RunnableMetadata(BaseModel): - """Metadata about the runnable.""" + """Metadata about a runnable implementation. + + This class stores information about how a runnable was built, including the + builder type, source definition/solution, and additional builder-specific data. + """ build_type: BuildType - """The type of build that produced this runnable.""" + """The type of build that produced this runnable (e.g., 'python', 'torch', 'triton', 'tvm_ffi').""" definition: str - """The definition that was used to build this runnable.""" + """Name of the definition that specifies the expected interface.""" solution: str - """The solution that was used to build this runnable.""" + """Name of the solution that was compiled into this runnable.""" misc: Dict[str, Any] - """Miscellaneous metadata about the runnable.""" + """Miscellaneous metadata about the runnable. Contents vary by builder type.""" class Runnable: - """A callable that is compiled from a solution. The runnable contains a callable, metadata, - and a closer function.""" + """An executable wrapper around a compiled solution. + + A Runnable encapsulates a callable function along with metadata about how it was built + and a cleanup function to release resources. It provides a uniform interface for + executing solutions regardless of the build system or language used. + """ metadata: RunnableMetadata - """The metadata for the runnable.""" + """Metadata about the build process and source solution.""" _callable: Callable[..., Any] - """The callable that is wrapped by the runnable.""" - _closer: Optional[Callable[[], None]] - """The closer function for the runnable.""" + """The underlying callable function.""" + _cleaner: Optional[Callable[[], None]] + """Optional cleanup function to release build artifacts and resources.""" def __init__( self, @@ -58,22 +68,22 @@ def __init__( self._cleaner = cleaner def __call__(self, **kwargs: Any) -> Any: - """ - Call the underlying function, and return the result. If the result is a single-element - tuple, unpack it. + """Execute the runnable with keyword arguments. + + This method calls the underlying compiled function with the provided inputs. + If the function returns a single-element tuple, it is automatically unpacked + to a scalar value for convenience. Parameters ---------- - args : Any - The positional arguments to pass to the underlying function. kwargs : Any - The keyword arguments to pass to the underlying function. + Keyword arguments matching the definition's input specification. Returns ------- Any - The result of the underlying function. If the result is a single-element tuple, - unpack it to a scalar value. + The result of the underlying function. Single-element tuples are unpacked + to scalar values. """ ret = self._callable(**kwargs) if isinstance(ret, tuple) and len(ret) == 1: @@ -83,25 +93,32 @@ def __call__(self, **kwargs: Any) -> Any: def call_dps(self, **kwargs: Any) -> Any: """Call a destination-passing style (DPS) function in value-returning style. - This method assumes the callable is destination-passing style:: - - function(**kwargs, **output_tensors) -> None + Some solutions use the destination-passing style, + where output tensors are passed as arguments and the function modifies them in-place:: - And calling this method will call the DPS function in value-returning style: + function(**input_tensors, **output_tensors) -> None - runnable.call_dps(**kwargs) -> output_tensors + This method provides a value-returning interface by automatically allocating output + tensors based on the definition, calling the DPS function, and returning the outputs:: - It will internally allocate output tensors, call the callable with the provided inputs - and allocated output tensors, and return the results. + result = runnable.call_dps(**input_tensors) # -> output_tensors Parameters ---------- kwargs : Any - The keyword arguments to pass to the underlying function. + Keyword arguments for input tensors matching the definition's input specification. Returns ------- Any + The output tensor(s). Single outputs are returned as-is, multiple outputs are + returned as a tuple, and empty outputs return None. + + Raises + ------ + ValueError + If the metadata does not contain the full definition object needed for + output tensor allocation. """ import torch @@ -142,7 +159,12 @@ def call_dps(self, **kwargs: Any) -> Any: return results def cleanup(self) -> None: - """Clean up the build artifacts/resources.""" + """Clean up build artifacts and release resources. + + This method calls the cleaner function if one was provided during construction. + It is idempotent: calling it multiple times is safe and has no additional effect + after the first call. + """ if self._closer: try: self._closer() diff --git a/flashinfer_bench/compile/utils.py b/flashinfer_bench/compile/utils.py index ff595d5f..1f98aa74 100644 --- a/flashinfer_bench/compile/utils.py +++ b/flashinfer_bench/compile/utils.py @@ -1,6 +1,7 @@ +"""Utility functions for building solutions.""" + from __future__ import annotations -import hashlib import re from pathlib import Path from typing import List @@ -9,19 +10,29 @@ def write_sources_to_path(path: Path, sources: List[SourceFile]) -> List[Path]: - """Write a list of source files to the given directory. + """Write source files to a directory and return their paths. - Creates parent directories as needed for files in subdirectories. - Overwrites files unconditionally (caller already determined a full build is needed). - Each source file should not contain parent directory traversal ("..") or absolute paths, and - should be unique. + This function writes all source files from a solution to a specified directory, + creating subdirectories as needed. It performs security checks to prevent path + traversal attacks and absolute path injection. Parameters ---------- path : Path - The directory path to write the source files to. - sources : list[SourceFile] - The list of source files to write. + The root directory where source files will be written. + sources : List[SourceFile] + The list of source files to write. Each file's path must be relative and + not contain parent directory references (".."). + + Returns + ------- + List[Path] + List of absolute paths to the written files. + + Raises + ------ + AssertionError + If any source file has an absolute path or contains path traversal. """ path.mkdir(parents=True, exist_ok=True) paths: List[Path] = [] @@ -45,20 +56,32 @@ def write_sources_to_path(path: Path, sources: List[SourceFile]) -> List[Path]: def create_package_name(sol: Solution, prefix: str = "") -> str: - """Create a package name for a solution. The name is created by normalizing the solution name - and hashing the sources. + """Generate a unique package name for a solution. + + The package name is constructed from three parts: + 1. A prefix (typically identifying the builder) + 2. The normalized solution name (alphanumeric and underscores only) + 3. A 6-character hash of the solution content + + This ensures the package name is both human-readable and uniquely identifies + the solution's content. Parameters ---------- sol : Solution The solution to create a package name for. - prefix : str - The prefix to add to the package name. + prefix : str, optional + The prefix to prepend to the package name. Default is empty string. Returns ------- str - The package name for the solution. + A unique package name in the format: {prefix}{normalized_name}_{hash}. + + Examples + -------- + >>> create_package_name(solution, "fib_python_") + 'fib_python_rmsnorm_v1_a3f2b1' """ # Normalize the solution name s = re.sub(r"[^0-9a-zA-Z_]", "_", sol.name) diff --git a/flashinfer_bench/data/solution.py b/flashinfer_bench/data/solution.py index eb3aa0d8..d821f830 100644 --- a/flashinfer_bench/data/solution.py +++ b/flashinfer_bench/data/solution.py @@ -21,6 +21,8 @@ class SupportedLanguages(str, Enum): """Python programming language.""" TRITON = "triton" """Triton GPU programming language.""" + CPP = "cpp" + """Pure C++ source code.""" CUDA = "cuda" """CUDA C++ programming language.""" @@ -141,22 +143,29 @@ def _validate_source_path_entry_point(self) -> "Solution": return self def get_entry_path(self) -> Path: - """Get the path to the entry source file. + """Extract the file path from the entry point specification. + + The entry point format is '{file_path}::{function_name}', and this method + returns the file path component as a Path object. Returns ------- - str - The path to the entry source file. + Path + The relative path to the entry source file (e.g., 'main.py', 'src/kernel.cu'). """ return Path(self.spec.entry_point.split("::")[0]) def get_entry_symbol(self) -> str: - """Extract function symbol from entry_point. + """Extract the function/symbol name from the entry point specification. + + The entry point format is '{file_path}::{function_name}', and this method + returns the function name component. This is the symbol that builders will + look up in the compiled module or imported Python module. Returns ------- str - The function symbol name to be loaded from the compiled module + The function or symbol name to be loaded (e.g., 'run', 'forward', 'kernel'). """ return self.spec.entry_point.split("::")[-1] @@ -175,12 +184,20 @@ def get_entry_source(self) -> Optional[SourceFile]: return None def hash(self) -> str: - """Hash the solution. It returns the SHA1 hash of the solution. + """Compute a deterministic hash of the solution content. + + The hash is computed from all fields that affect the solution's behavior: + name, definition, language, entry point, dependencies, and all source file + paths and contents. This ensures that any meaningful change to the solution + results in a different hash. + + The hash is used for caching build artifacts - solutions with the same hash + can reuse the same cached build result. Returns ------- str - The hash of the solution. + A SHA1 hash (40 hex characters) uniquely identifying this solution's content. """ h = hashlib.sha1() for s in (