|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import sys |
4 | 3 | import os |
5 | 4 | import ctypes |
6 | | -import functools |
7 | 5 | import pathlib |
8 | 6 |
|
9 | 7 | from typing import ( |
10 | | - Any, |
11 | 8 | Callable, |
12 | | - List, |
13 | 9 | Union, |
14 | 10 | NewType, |
15 | 11 | Optional, |
16 | 12 | TYPE_CHECKING, |
17 | | - TypeVar, |
18 | | - Generic, |
19 | 13 | ) |
20 | | -from typing_extensions import TypeAlias |
21 | 14 |
|
| 15 | +from llama_cpp._ctypes_extensions import ( |
| 16 | + load_shared_library, |
| 17 | + byref, |
| 18 | + ctypes_function_for_shared_library, |
| 19 | +) |
22 | 20 |
|
23 | | -# Load the library |
24 | | -def _load_shared_library(lib_base_name: str): |
25 | | - # Construct the paths to the possible shared library names |
26 | | - _base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" |
27 | | - # Searching for the library in the current directory under the name "libllama" (default name |
28 | | - # for llamacpp) and "llama" (default name for this repo) |
29 | | - _lib_paths: List[pathlib.Path] = [] |
30 | | - # Determine the file extension based on the platform |
31 | | - if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"): |
32 | | - _lib_paths += [ |
33 | | - _base_path / f"lib{lib_base_name}.so", |
34 | | - ] |
35 | | - elif sys.platform == "darwin": |
36 | | - _lib_paths += [ |
37 | | - _base_path / f"lib{lib_base_name}.so", |
38 | | - _base_path / f"lib{lib_base_name}.dylib", |
39 | | - ] |
40 | | - elif sys.platform == "win32": |
41 | | - _lib_paths += [ |
42 | | - _base_path / f"{lib_base_name}.dll", |
43 | | - _base_path / f"lib{lib_base_name}.dll", |
44 | | - ] |
45 | | - else: |
46 | | - raise RuntimeError("Unsupported platform") |
47 | | - |
48 | | - if "LLAMA_CPP_LIB" in os.environ: |
49 | | - lib_base_name = os.environ["LLAMA_CPP_LIB"] |
50 | | - _lib = pathlib.Path(lib_base_name) |
51 | | - _base_path = _lib.parent.resolve() |
52 | | - _lib_paths = [_lib.resolve()] |
53 | | - |
54 | | - cdll_args = dict() # type: ignore |
55 | | - |
56 | | - # Add the library directory to the DLL search path on Windows (if needed) |
57 | | - if sys.platform == "win32": |
58 | | - os.add_dll_directory(str(_base_path)) |
59 | | - os.environ["PATH"] = str(_base_path) + os.pathsep + os.environ["PATH"] |
60 | | - |
61 | | - if sys.platform == "win32" and sys.version_info >= (3, 8): |
62 | | - os.add_dll_directory(str(_base_path)) |
63 | | - if "CUDA_PATH" in os.environ: |
64 | | - os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin")) |
65 | | - os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib")) |
66 | | - if "HIP_PATH" in os.environ: |
67 | | - os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin")) |
68 | | - os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib")) |
69 | | - cdll_args["winmode"] = ctypes.RTLD_GLOBAL |
70 | | - |
71 | | - # Try to load the shared library, handling potential errors |
72 | | - for _lib_path in _lib_paths: |
73 | | - if _lib_path.exists(): |
74 | | - try: |
75 | | - return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore |
76 | | - except Exception as e: |
77 | | - raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}") |
78 | | - |
79 | | - raise FileNotFoundError( |
80 | | - f"Shared library with base name '{lib_base_name}' not found" |
| 21 | +if TYPE_CHECKING: |
| 22 | + from llama_cpp._ctypes_extensions import ( |
| 23 | + CtypesCData, |
| 24 | + CtypesArray, |
| 25 | + CtypesPointer, |
| 26 | + CtypesVoidPointer, |
| 27 | + CtypesRef, |
| 28 | + CtypesPointerOrRef, |
| 29 | + CtypesFuncPointer, |
81 | 30 | ) |
82 | 31 |
|
83 | 32 |
|
84 | 33 | # Specify the base name of the shared library to load |
85 | 34 | _lib_base_name = "llama" |
86 | | - |
| 35 | +_override_base_path = os.environ.get("LLAMA_CPP_LIB_PATH") |
| 36 | +_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _override_base_path is None else pathlib.Path(_override_base_path) |
87 | 37 | # Load the library |
88 | | -_lib = _load_shared_library(_lib_base_name) |
89 | | - |
90 | | - |
91 | | -# ctypes sane type hint helpers |
92 | | -# |
93 | | -# - Generic Pointer and Array types |
94 | | -# - PointerOrRef type with a type hinted byref function |
95 | | -# |
96 | | -# NOTE: Only use these for static type checking not for runtime checks |
97 | | -# no good will come of that |
98 | | - |
99 | | -if TYPE_CHECKING: |
100 | | - CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore |
101 | | - |
102 | | - CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore |
103 | | - |
104 | | - CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore |
105 | | - |
106 | | - CtypesVoidPointer: TypeAlias = ctypes.c_void_p |
107 | | - |
108 | | - class CtypesRef(Generic[CtypesCData]): |
109 | | - pass |
110 | | - |
111 | | - CtypesPointerOrRef: TypeAlias = Union[ |
112 | | - CtypesPointer[CtypesCData], CtypesRef[CtypesCData] |
113 | | - ] |
114 | | - |
115 | | - CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore |
116 | | - |
117 | | -F = TypeVar("F", bound=Callable[..., Any]) |
118 | | - |
119 | | - |
120 | | -def ctypes_function_for_shared_library(lib: ctypes.CDLL): |
121 | | - def ctypes_function( |
122 | | - name: str, argtypes: List[Any], restype: Any, enabled: bool = True |
123 | | - ): |
124 | | - def decorator(f: F) -> F: |
125 | | - if enabled: |
126 | | - func = getattr(lib, name) |
127 | | - func.argtypes = argtypes |
128 | | - func.restype = restype |
129 | | - functools.wraps(f)(func) |
130 | | - return func |
131 | | - else: |
132 | | - return f |
133 | | - |
134 | | - return decorator |
135 | | - |
136 | | - return ctypes_function |
137 | | - |
| 38 | +_lib = load_shared_library(_lib_base_name, _base_path) |
138 | 39 |
|
139 | 40 | ctypes_function = ctypes_function_for_shared_library(_lib) |
140 | 41 |
|
141 | 42 |
|
142 | | -def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]: |
143 | | - """Type-annotated version of ctypes.byref""" |
144 | | - ... |
145 | | - |
146 | | - |
147 | | -byref = ctypes.byref # type: ignore |
148 | | - |
149 | 43 | # from ggml.h |
150 | 44 | # // NOTE: always add types at the end of the enum to keep backward compatibility |
151 | 45 | # enum ggml_type { |
|
0 commit comments