
Commit ffff841

migrate clip to mtmd
1 parent 03ce53b commit ffff841

File tree

4 files changed: +279 -57 lines changed


llama_cpp/llama_chat_format.py

Lines changed: 52 additions & 55 deletions
@@ -2716,7 +2716,7 @@ class Llava15ChatHandler:
         "{% endif %}"
     )
 
-    def __init__(self, clip_model_path: str, verbose: bool = True):
+    def __init__(self, clip_model_path: str, llama_model: llama.Llama, verbose: bool = True):
         import llama_cpp.mtmd_cpp as mtmd_cpp
 
         self.clip_model_path = clip_model_path
@@ -3630,62 +3630,59 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
         return split_text
 
     def eval_image(self, llama: llama.Llama, image_url: str):
-        import llama_cpp
+        image_bytes = self.load_image(image_url)
+
+        # Create bitmap manager if not exists
+        if self._bitmap_manager is None:
+            self._bitmap_manager = self._mtmd_cpp.BitmapManager()
+
+        # Create bitmap from bytes
+        if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes):
+            raise ValueError("Failed to create bitmap from image bytes")
+
+        # Create input chunks for the bitmap
+        chunks = self._mtmd_cpp.mtmd_input_chunks_init()
+        if chunks is None:
+            raise ValueError("Failed to create input chunks")
+
+        # Create input text with media marker
+        # Get media marker from context params
+        params = self._mtmd_cpp.mtmd_context_params_default()
+        text = self._mtmd_cpp.mtmd_input_text()
+        text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker()
+        text.add_special = False
+        text.parse_special = True
+
+        # Tokenize with bitmap
+        if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0:
+            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
+            raise ValueError("Failed to tokenize image")
+
+        # Get new n_past after evaluation
+        n_past = ctypes.c_int(llama.n_tokens)
+        n_past_p = ctypes.pointer(n_past)
 
-        n_tokens = 256
-        if llama.n_tokens + n_tokens > llama.n_ctx():
-            raise ValueError(
-                f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}"
-            )
+        # Evaluate chunks
+        if self._mtmd_cpp.mtmd_helper_eval_chunks(
+            self.clip_ctx,
+            llama.ctx,
+            chunks,
+            llama.n_tokens,
+            0,  # seq_id
+            llama.n_batch,
+            True,  # logits_last
+            n_past_p
+        ) != 0:
+            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
+            raise ValueError("Failed to evaluate chunks")
+
+        # Update n_tokens
+        llama.input_ids[llama.n_tokens : n_past.value] = -1
+        llama.n_tokens = n_past.value
 
-        img_bytes = self.load_image(image_url)
-        img_u8_p = self._mtmd_cpp.clip_image_u8_init()
-        if not self._mtmd_cpp.clip_image_load_from_bytes(
-            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
-            ctypes.c_size_t(len(img_bytes)),
-            img_u8_p,
-        ):
-            self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-            raise ValueError("Failed to load image.")
-
-        img_f32_p = self._mtmd_cpp.clip_image_f32_batch_init()
-        if not self._mtmd_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
-            self._mtmd_cpp.clip_image_f32_batch_free(img_f32_p)
-            self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-            raise ValueError("Failed to preprocess image.")
-
-        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
-        embed = (ctypes.c_float * (n_tokens * n_embd))()
-        if not self._mtmd_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
-            self._mtmd_cpp.clip_image_f32_batch_free(img_f32_p)
-            self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-            raise ValueError("Failed to encode image.")
-
-        self._mtmd_cpp.clip_image_f32_batch_free(img_f32_p)
-        self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-        llama_cpp.llama_set_causal_attn(llama.ctx, False)
-
-        seq_id_0 = (ctypes.c_int32 * 1)()
-        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
-        for i in range(n_tokens):
-            seq_ids[i] = seq_id_0
-
-        batch = llama_cpp.llama_batch()
-        batch.n_tokens = n_tokens
-        batch.token = None
-        batch.embd = embed
-        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
-        batch.seq_id = seq_ids
-        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
-        batch.logits = (ctypes.c_int8 * n_tokens)()
-
-        if llama_cpp.llama_decode(llama.ctx, batch):
-            raise ValueError("Failed to decode image.")
-
-        llama_cpp.llama_set_causal_attn(llama.ctx, True)
-        # Required to avoid issues with hf tokenizer
-        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
-        llama.n_tokens += n_tokens
+        # Cleanup
+        self._mtmd_cpp.mtmd_input_chunks_free(chunks)
+        self._bitmap_manager.clear()
 
 
     def _accumulate_chunks(
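Note on the hunk above: eval_image no longer hand-rolls the CLIP pipeline (load bytes, preprocess, batch-encode, then llama_decode of a fixed 256-token embedding batch); it delegates tokenization and decoding to mtmd's helper, which reports the resulting n_past back through an out-pointer. The only API-visible consequence is the new constructor parameter. A minimal sketch of a migrated call site (the model file names are illustrative, not part of this commit):

import llama_cpp
from llama_cpp.llama_chat_format import Llava15ChatHandler

model = llama_cpp.Llama("llava-v1.5-7b-Q8_0.gguf", n_ctx=2048)  # illustrative path

# Before this commit:
#   handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf")
# After it, the handler also needs the text model up front:
handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf", llama_model=model)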

llama_cpp/mtmd_cpp.py

Lines changed: 147 additions & 2 deletions
@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import os
+import ctypes
 from ctypes import (
     c_bool,
     c_char_p,
     c_int,
+    c_int32,
     c_uint8,
     c_uint32,
     c_float,
@@ -17,6 +19,7 @@
 )
 import pathlib
 from typing import (
+    List,
     Union,
     NewType,
     Optional,
@@ -31,19 +34,161 @@
 )
 
 if TYPE_CHECKING:
+    from llama_cpp.llama_types import (
+        llama_token,
+        llama_pos,
+    )
     from llama_cpp._ctypes_extensions import (
         CtypesArray,
+        CtypesPointer,
     )
 
+# Define input text structure
+class mtmd_input_text(Structure):
+    _fields_ = [
+        ("text", c_char_p),
+        ("add_special", c_bool),
+        ("parse_special", c_bool),
+    ]
+
+# Define context parameters structure
+class mtmd_context_params(Structure):
+    _fields_ = [
+        ("use_gpu", c_bool),
+        ("print_timings", c_bool),
+        ("n_threads", c_int),
+        ("verbosity", c_int),
+        ("image_marker", c_char_p),  # const char*
+        ("media_marker", c_char_p),  # const char*
+    ]
+
+# Define input chunk type enum
+mtmd_input_chunk_type = c_int
+(
+    MTMD_INPUT_CHUNK_TYPE_TEXT,
+    MTMD_INPUT_CHUNK_TYPE_IMAGE,
+    MTMD_INPUT_CHUNK_TYPE_AUDIO,
+) = (0, 1, 2)
+
+# Define slice template enum
+mtmd_slice_tmpl = c_int
+(
+    MTMD_SLICE_TMPL_NONE,
+    MTMD_SLICE_TMPL_MINICPMV_2_5,
+    MTMD_SLICE_TMPL_MINICPMV_2_6,
+    MTMD_SLICE_TMPL_LLAMA4,
+) = (0, 1, 2, 3)
+
+# Define whisper filters structure
+class whisper_filters(Structure):
+    _fields_ = [
+        ("n_mel", c_int),
+    ]
+
+# Define mtmd_context structure
+class mtmd_context(Structure):
+    _fields_ = [
+        ("ctx_v", c_void_p),  # clip_ctx*
+        ("ctx_a", c_void_p),  # clip_ctx*
+        ("text_model", c_void_p),  # const llama_model*
+        ("image_embd_v", POINTER(c_float)),  # std::vector<float>
+        ("print_timings", c_bool),
+        ("n_threads", c_int),
+        ("media_marker", c_char_p),  # std::string
+        ("n_embd_text", c_int),
+        ("img_beg", c_char_p),  # std::string
+        ("img_end", c_char_p),  # std::string
+        ("aud_beg", c_char_p),  # std::string
+        ("aud_end", c_char_p),  # std::string
+        ("slice_tmpl", c_int),  # mtmd_slice_tmpl
+        ("tok_ov_img_start", llama_cpp.llama_token),
+        ("tok_ov_img_end", llama_cpp.llama_token),
+        ("tok_slices_start", llama_cpp.llama_token),
+        ("tok_slices_end", llama_cpp.llama_token),
+        ("tok_sli_img_start", llama_cpp.llama_token),
+        ("tok_sli_img_end", llama_cpp.llama_token),
+        ("tok_sli_img_mid", llama_cpp.llama_token),
+        ("tok_row_end", llama_cpp.llama_token),
+        ("tok_row_end_trail", c_bool),
+        ("ov_img_first", c_bool),
+        ("use_mrope", c_bool),
+        ("w_filters", whisper_filters),
+    ]
+
+# Define bitmap structure
+class mtmd_bitmap(Structure):
+    _fields_ = [
+        ("nx", c_uint32),
+        ("ny", c_uint32),
+        ("data", POINTER(c_uint8)),  # Vector represented as pointer
+        ("id", c_char_p),
+        ("is_audio", c_bool),
+    ]
+
+# Define image tokens structure
+class mtmd_image_tokens(Structure):
+    _fields_ = [
+        ("nx", c_uint32),
+        ("ny", c_uint32),
+        ("use_mrope_pos", c_bool),
+        ("batch_f32", c_void_p),  # clip_image_f32_batch
+        ("id", c_char_p),
+    ]
 
-# Specify the base name of the shared library to load
+# Define audio tokens structure
+class mtmd_audio_tokens(Structure):
+    _fields_ = [
+        ("n_tokens", c_uint32),
+        ("batch_f32", c_void_p),  # clip_image_f32_batch
+        ("id", c_char_p),
+    ]
+
+# Define input chunk structure
+class mtmd_input_chunk(Structure):
+    _fields_ = [
+        ("type", mtmd_input_chunk_type),
+        ("tokens_text", POINTER(llama_cpp.llama_token)),  # Vector represented as pointer
+        ("tokens_image", c_void_p),  # mtmd_image_tokens_ptr
+        ("tokens_audio", c_void_p),  # mtmd_audio_tokens_ptr
+    ]
+
+# Define input chunks structure
+class mtmd_input_chunks(Structure):
+    _fields_ = [
+        ("entries", POINTER(mtmd_input_chunk)),  # Vector represented as pointer
+    ]
+
+# Define context pointer type
+mtmd_context_p = NewType("mtmd_context_p", int)
+mtmd_context_p_ctypes = c_void_p
+
+# Define bitmap pointer type
+mtmd_bitmap_p = NewType("mtmd_bitmap_p", int)
+mtmd_bitmap_p_ctypes = c_void_p
+
+# Define input chunks pointer type
+mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int)
+mtmd_input_chunks_p_ctypes = c_void_p
+
+# Define input chunk pointer type
+mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int)
+mtmd_input_chunk_p_ctypes = c_void_p
+
+# Define image tokens pointer type
+mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int)
+mtmd_image_tokens_p_ctypes = c_void_p
+
+# Define audio tokens pointer type
+mtmd_audio_tokens_p = NewType("mtmd_audio_tokens_p", int)
+mtmd_audio_tokens_p_ctypes = c_void_p
+
+# Load the library
 _libmtmd_base_name = "mtmd"
 _libmtmd_override_path = os.environ.get("MTMD_CPP_LIB")
 _libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path()
 
 # Load the library
 _libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path)
-
 ctypes_function = ctypes_function_for_shared_library(_libmtmd)
 
 ################################################
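Taken together with the eval_image hunk above, these bindings are driven in a fixed sequence: bitmap in, chunks out, helper decode. A condensed, standalone sketch of that handshake, assuming the BitmapManager wrapper and the mtmd_* functions bound in this module, plus already-initialized mtmd and llama contexts (a sketch of the call order, not a tested API):

import ctypes
import llama_cpp.mtmd_cpp as mtmd_cpp

def eval_image_sketch(mtmd_ctx, llama_ctx, image_bytes, n_past_start, n_batch):
    # Wrap the raw image bytes in an mtmd bitmap
    bitmaps = mtmd_cpp.BitmapManager()
    if not bitmaps.add_from_memory(mtmd_ctx, image_bytes):
        raise ValueError("Failed to create bitmap from image bytes")

    # Tokenize a media-marker prompt together with the bitmap into chunks
    chunks = mtmd_cpp.mtmd_input_chunks_init()
    params = mtmd_cpp.mtmd_context_params_default()
    text = mtmd_cpp.mtmd_input_text()
    text.text = params.media_marker if params.media_marker else mtmd_cpp.mtmd_default_marker()
    text.add_special = False
    text.parse_special = True
    if mtmd_cpp.mtmd_tokenize(mtmd_ctx, chunks, text,
                              bitmaps.c_ptr(), len(bitmaps.entries)) != 0:
        mtmd_cpp.mtmd_input_chunks_free(chunks)
        raise ValueError("Failed to tokenize image")

    # Let the helper decode all chunks; it reports the new n_past via pointer
    n_past = ctypes.c_int(n_past_start)
    if mtmd_cpp.mtmd_helper_eval_chunks(mtmd_ctx, llama_ctx, chunks,
                                        n_past_start, 0, n_batch, True,
                                        ctypes.pointer(n_past)) != 0:
        mtmd_cpp.mtmd_input_chunks_free(chunks)
        raise ValueError("Failed to evaluate chunks")

    mtmd_cpp.mtmd_input_chunks_free(chunks)
    bitmaps.clear()
    return n_past.value

The design point worth noting: mtmd_helper_eval_chunks handles batching and positions internally and writes the final n_past through the out-pointer, which is why the caller in eval_image only has to mark the consumed positions in input_ids afterwards.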

tests/monalisa.jpg

529 KB

tests/test_llava.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import multiprocessing
+import ctypes
+
+from huggingface_hub import hf_hub_download
+
+import pytest
+
+import llama_cpp
+
+@pytest.fixture
+def mmproj_model_path():
+    repo_id = "second-state/Llava-v1.5-7B-GGUF"
+    filename = "llava-v1.5-7b-mmproj-model-f16.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+@pytest.fixture
+def llava_cpp_model_path():
+    repo_id = "second-state/Llava-v1.5-7B-GGUF"
+    filename = "llava-v1.5-7b-Q8_0.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+def test_real_llava(llava_cpp_model_path, mmproj_model_path):
+    print("initializing model")
+    model = llama_cpp.Llama(
+        llava_cpp_model_path,
+        n_ctx=2048,
+        n_batch=512,
+        n_threads=multiprocessing.cpu_count(),
+        n_threads_batch=multiprocessing.cpu_count(),
+        logits_all=False,
+        verbose=False,
+    )
+
+    # Initialize the LLaVA chat handler
+    from llama_cpp.llama_chat_format import Llava15ChatHandler
+    print("initializing chat handler")
+    chat_handler = Llava15ChatHandler(clip_model_path=mmproj_model_path, llama_model=model)
+
+    # Create a chat message with the image
+    print("creating chat message")
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": "./tests/monalisa.jpg"
+                },
+                {
+                    "type": "text",
+                    "text": "Do you know who drew this painting?"
+                }
+            ]
+        }
+    ]
+
+    # Generate response
+    print("generating response")
+    response = chat_handler(
+        llama=model,
+        messages=messages,
+        max_tokens=200,
+        temperature=0.2,
+        top_p=0.95,
+        stream=False
+    )
+
+    print("response", response)
+    # Check that we got a response
+    assert response is not None
+    assert "choices" in response
+    assert len(response["choices"]) > 0
+    assert "message" in response["choices"][0]
+    assert "content" in response["choices"][0]["message"]
+
+    # The response should mention Leonardo da Vinci
+    content = response["choices"][0]["message"]["content"].lower()
+    assert "leonardo" in content and "vinci" in content  # Artist name should be in response
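Both fixtures pull their GGUF files from Hugging Face on first run, so the test assumes network access and enough memory for the Q8_0 weights; invoking it with something like pytest -s tests/test_llava.py should exercise the new mtmd image-eval path end to end.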
