
Commit 6f8ec8b

Sync llama/mtmd API change, support clip flash-attn
1 parent ab70ead commit 6f8ec8b

File tree

3 files changed: +48 -7 lines changed


llama_cpp/llama_chat_format.py

Lines changed: 7 additions & 6 deletions

@@ -2811,17 +2811,18 @@ def _init_mtmd_context(self, llama_model: llama.Llama):
 
         with suppress_stdout_stderr(disable=self.verbose):
             # Get default parameters
-            ctx_params = self._mtmd_cpp.mtmd_context_params_default()
-            ctx_params.use_gpu = True  # TODO: Make this configurable
-            ctx_params.print_timings = self.verbose
-            ctx_params.n_threads = llama_model.n_threads
-            ctx_params.verbosity = 2 if self.verbose else 0  # GGML_LOG_LEVEL_INFO = 2
+            mctx_params = self._mtmd_cpp.mtmd_context_params_default()
+            mctx_params.use_gpu = True  # TODO: Make this configurable
+            mctx_params.print_timings = self.verbose
+            mctx_params.n_threads = llama_model.n_threads
+            mctx_params.verbosity = 2 if self.verbose else 0  # GGML_LOG_LEVEL_INFO = 2
+            mctx_params.flash_attn_type = self._mtmd_cpp.clip_flash_attn_type.CLIP_FLASH_ATTN_TYPE_AUTO
 
             # Initialize mtmd context
             self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file(
                 self.clip_model_path.encode(),
                 llama_model.model,
-                ctx_params
+                mctx_params
             )
 
             if self.mtmd_ctx is None:
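The rename from ctx_params to mctx_params keeps the mtmd parameters distinct from the llama context parameters, and flash attention is left on CLIP_FLASH_ATTN_TYPE_AUTO, matching the existing use_gpu TODO. A sketch of how the tri-state enum could be surfaced as an option; the flash_attn parameter below is hypothetical, not part of this commit:

# Hypothetical sketch: the commit hard-codes AUTO; the `flash_attn`
# keyword below is an assumed extension, shown for illustration only.
def _init_mtmd_context(self, llama_model, flash_attn: bool | None = None):
    with suppress_stdout_stderr(disable=self.verbose):
        mctx_params = self._mtmd_cpp.mtmd_context_params_default()
        mctx_params.use_gpu = True
        mctx_params.print_timings = self.verbose
        mctx_params.n_threads = llama_model.n_threads
        mctx_params.verbosity = 2 if self.verbose else 0  # GGML_LOG_LEVEL_INFO
        # None -> AUTO (probe support at load), True -> ENABLED, False -> DISABLED
        fa = self._mtmd_cpp.clip_flash_attn_type
        mctx_params.flash_attn_type = (
            fa.CLIP_FLASH_ATTN_TYPE_AUTO
            if flash_attn is None
            else fa.CLIP_FLASH_ATTN_TYPE_ENABLED
            if flash_attn
            else fa.CLIP_FLASH_ATTN_TYPE_DISABLED
        )
        self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file(
            self.clip_model_path.encode(), llama_model.model, mctx_params
        )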

llama_cpp/llama_cpp.py

Lines changed: 7 additions & 1 deletion

@@ -1401,12 +1401,18 @@ def llama_supports_gpu_offload() -> bool:
 def llama_supports_rpc() -> bool:
     ...
 
-
+# // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
+# // In some cases the requested values via llama_context_params may differ from the actual values used by the context
 # LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
 @ctypes_function("llama_n_ctx", [llama_context_p_ctypes], ctypes.c_uint32)
 def llama_n_ctx(ctx: llama_context_p, /) -> int:
     ...
 
+# LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
+@ctypes_function("llama_n_ctx_seq", [llama_context_p_ctypes], ctypes.c_uint32)
+def llama_n_ctx_seq(ctx: llama_context_p, /) -> int:
+    ...
+
 
 # LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
 @ctypes_function("llama_n_batch", [llama_context_p_ctypes], ctypes.c_uint32)
llama_cpp/mtmd_cpp.py

Lines changed: 34 additions & 0 deletions

@@ -108,13 +108,44 @@ class mtmd_input_text(Structure):
 mtmd_input_text_p = NewType("mtmd_input_text_p", int)
 mtmd_input_text_p_ctypes = POINTER(mtmd_input_text)
 
+# enum clip_flash_attn_type {
+#     CLIP_FLASH_ATTN_TYPE_AUTO = -1,
+#     CLIP_FLASH_ATTN_TYPE_DISABLED = 0,
+#     CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
+# };
+class clip_flash_attn_type(enum.IntEnum):
+    CLIP_FLASH_ATTN_TYPE_AUTO = -1
+    CLIP_FLASH_ATTN_TYPE_DISABLED = 0
+    CLIP_FLASH_ATTN_TYPE_ENABLED = 1
+
+# struct clip_context_params {
+#     bool use_gpu;
+#     enum ggml_log_level verbosity;
+#     enum clip_flash_attn_type flash_attn_type;
+#     int image_min_tokens;
+#     int image_max_tokens;
+# };
+class clip_context_params(Structure):
+    _fields_ = [
+        ("use_gpu", c_bool),
+        ("verbosity", c_int),
+        ("flash_attn_type", c_int),
+        ("image_min_tokens", c_int),
+        ("image_max_tokens", c_int),
+    ]
+
 # struct mtmd_context_params {
 #     bool use_gpu;
 #     bool print_timings;
 #     int n_threads;
 #     enum ggml_log_level verbosity;
 #     const char * image_marker; // deprecated, use media_marker instead
 #     const char * media_marker;
+#     enum llama_flash_attn_type flash_attn_type;
+
+#     // limit number of image tokens, only for vision models with dynamic resolution
+#     int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
+#     int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
 # };
 class mtmd_context_params(Structure):
     _fields_ = [
@@ -124,6 +155,9 @@ class mtmd_context_params(Structure):
         ("verbosity", c_int),
         ("image_marker", c_char_p),
         ("media_marker", c_char_p),
+        ("flash_attn_type", c_int),
+        ("image_min_tokens", c_int),
+        ("image_max_tokens", c_int),
     ]
 
 mtmd_context_params_p = NewType("mtmd_context_params_p", int)
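Since the new members are plain ctypes fields, they can be set directly on the defaults struct; the IntEnum coerces to c_int on assignment. A small sketch, where the thread count and token cap are illustrative values rather than defaults:

# Sketch: populate the extended mtmd_context_params from Python.
from llama_cpp import mtmd_cpp

params = mtmd_cpp.mtmd_context_params_default()
params.use_gpu = True
params.n_threads = 4  # illustrative value
params.flash_attn_type = mtmd_cpp.clip_flash_attn_type.CLIP_FLASH_ATTN_TYPE_AUTO
# Only honored by vision models with dynamic resolution; defaults come from
# model metadata, so override only when a hard cap is needed.
params.image_max_tokens = 1024  # illustrative value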
