Commit ba0ae7f

migrate clip to mtmd

1 parent 5c41448 commit ba0ae7f

File tree

4 files changed: +279 -57 lines

llama_cpp/llama_chat_format.py

Lines changed: 52 additions & 55 deletions

@@ -2722,7 +2722,7 @@ class Llava15ChatHandler:
         "{% endif %}"
     )

-    def __init__(self, clip_model_path: str, verbose: bool = True):
+    def __init__(self, clip_model_path: str, llama_model: llama.Llama, verbose: bool = True):
         import llama_cpp.mtmd_cpp as mtmd_cpp

         self.clip_model_path = clip_model_path
@@ -3639,62 +3639,59 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
         return split_text

     def eval_image(self, llama: llama.Llama, image_url: str):
-        import llama_cpp
+        image_bytes = self.load_image(image_url)
+
+        # Create bitmap manager if not exists
+        if self._bitmap_manager is None:
+            self._bitmap_manager = self._mtmd_cpp.BitmapManager()
+
+        # Create bitmap from bytes
+        if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes):
+            raise ValueError("Failed to create bitmap from image bytes")
+
+        # Create input chunks for the bitmap
+        chunks = self._mtmd_cpp.mtmd_input_chunks_init()
+        if chunks is None:
+            raise ValueError("Failed to create input chunks")
+
+        # Create input text with media marker
+        # Get media marker from context params
+        params = self._mtmd_cpp.mtmd_context_params_default()
+        text = self._mtmd_cpp.mtmd_input_text()
+        text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker()
+        text.add_special = False
+        text.parse_special = True
+
+        # Tokenize with bitmap
+        if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0:
+            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
+            raise ValueError("Failed to tokenize image")
+
+        # Get new n_past after evaluation
+        n_past = ctypes.c_int(llama.n_tokens)
+        n_past_p = ctypes.pointer(n_past)

-        n_tokens = 256
-        if llama.n_tokens + n_tokens > llama.n_ctx():
-            raise ValueError(
-                f"Prompt exceeds n_ctx: {llama.n_tokens + n_tokens} > {llama.n_ctx()}"
-            )
+        # Evaluate chunks
+        if self._mtmd_cpp.mtmd_helper_eval_chunks(
+            self.clip_ctx,
+            llama.ctx,
+            chunks,
+            llama.n_tokens,
+            0,  # seq_id
+            llama.n_batch,
+            True,  # logits_last
+            n_past_p
+        ) != 0:
+            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
+            raise ValueError("Failed to evaluate chunks")
+
+        # Update n_tokens
+        llama.input_ids[llama.n_tokens : n_past.value] = -1
+        llama.n_tokens = n_past.value

-        img_bytes = self.load_image(image_url)
-        img_u8_p = self._mtmd_cpp.clip_image_u8_init()
-        if not self._mtmd_cpp.clip_image_load_from_bytes(
-            ctypes.create_string_buffer(img_bytes, len(img_bytes)),
-            ctypes.c_size_t(len(img_bytes)),
-            img_u8_p,
-        ):
-            self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-            raise ValueError("Failed to load image.")
-
-        img_f32_p = self._mtmd_cpp.clip_image_f32_batch_init()
-        if not self._mtmd_cpp.clip_image_preprocess(self.clip_ctx, img_u8_p, img_f32_p):
-            self._mtmd_cpp.clip_image_f32_batch_free(img_f32_p)
-            self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-            raise ValueError("Failed to preprocess image.")
-
-        n_embd = llama_cpp.llama_model_n_embd(llama._model.model)
-        embed = (ctypes.c_float * (n_tokens * n_embd))()
-        if not self._mtmd_cpp.clip_image_batch_encode(self.clip_ctx, llama.n_threads, img_f32_p, embed):
-            self._mtmd_cpp.clip_image_f32_batch_free(img_f32_p)
-            self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-            raise ValueError("Failed to encode image.")
-
-        self._mtmd_cpp.clip_image_f32_batch_free(img_f32_p)
-        self._mtmd_cpp.clip_image_u8_free(img_u8_p)
-        llama_cpp.llama_set_causal_attn(llama.ctx, False)
-
-        seq_id_0 = (ctypes.c_int32 * 1)()
-        seq_ids = (ctypes.POINTER(ctypes.c_int32) * (n_tokens + 1))()
-        for i in range(n_tokens):
-            seq_ids[i] = seq_id_0
-
-        batch = llama_cpp.llama_batch()
-        batch.n_tokens = n_tokens
-        batch.token = None
-        batch.embd = embed
-        batch.pos = (ctypes.c_int32 * n_tokens)(*[i + llama.n_tokens for i in range(n_tokens)])
-        batch.seq_id = seq_ids
-        batch.n_seq_id = (ctypes.c_int32 * n_tokens)(*([1] * n_tokens))
-        batch.logits = (ctypes.c_int8 * n_tokens)()
-
-        if llama_cpp.llama_decode(llama.ctx, batch):
-            raise ValueError("Failed to decode image.")
-
-        llama_cpp.llama_set_causal_attn(llama.ctx, True)
-        # Required to avoid issues with hf tokenizer
-        llama.input_ids[llama.n_tokens : llama.n_tokens + n_tokens] = -1
-        llama.n_tokens += n_tokens
+        # Cleanup
+        self._mtmd_cpp.mtmd_input_chunks_free(chunks)
+        self._bitmap_manager.clear()


     def _accumulate_chunks(
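
The new __init__ signature is a breaking change for existing callers: the handler now takes the Llama instance up front, presumably so the mtmd context can be bound to the text model (mtmd_context stores it as text_model). A minimal usage sketch of the updated construction; the file names are placeholders taken from the test added below:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    model = Llama("llava-v1.5-7b-Q8_0.gguf", n_ctx=2048, verbose=False)

    # Before this commit the handler only took the CLIP model path:
    #     Llava15ChatHandler(clip_model_path="llava-v1.5-7b-mmproj-model-f16.gguf")
    # After it, the text model is passed alongside the mmproj model:
    chat_handler = Llava15ChatHandler(
        clip_model_path="llava-v1.5-7b-mmproj-model-f16.gguf",
        llama_model=model,
    )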

llama_cpp/mtmd_cpp.py

Lines changed: 147 additions & 2 deletions

@@ -1,10 +1,12 @@
 from __future__ import annotations

 import os
+import ctypes
 from ctypes import (
     c_bool,
     c_char_p,
     c_int,
+    c_int32,
     c_uint8,
     c_uint32,
     c_float,
@@ -17,6 +19,7 @@
 )
 import pathlib
 from typing import (
+    List,
     Union,
     NewType,
     Optional,
@@ -31,19 +34,161 @@
 )

 if TYPE_CHECKING:
+    from llama_cpp.llama_types import (
+        llama_token,
+        llama_pos,
+    )
     from llama_cpp._ctypes_extensions import (
         CtypesArray,
+        CtypesPointer,
     )

+# Define input text structure
+class mtmd_input_text(Structure):
+    _fields_ = [
+        ("text", c_char_p),
+        ("add_special", c_bool),
+        ("parse_special", c_bool),
+    ]
+
+# Define context parameters structure
+class mtmd_context_params(Structure):
+    _fields_ = [
+        ("use_gpu", c_bool),
+        ("print_timings", c_bool),
+        ("n_threads", c_int),
+        ("verbosity", c_int),
+        ("image_marker", c_char_p),  # const char*
+        ("media_marker", c_char_p),  # const char*
+    ]
+
+# Define input chunk type enum
+mtmd_input_chunk_type = c_int
+(
+    MTMD_INPUT_CHUNK_TYPE_TEXT,
+    MTMD_INPUT_CHUNK_TYPE_IMAGE,
+    MTMD_INPUT_CHUNK_TYPE_AUDIO,
+) = (0, 1, 2)
+
+# Define slice template enum
+mtmd_slice_tmpl = c_int
+(
+    MTMD_SLICE_TMPL_NONE,
+    MTMD_SLICE_TMPL_MINICPMV_2_5,
+    MTMD_SLICE_TMPL_MINICPMV_2_6,
+    MTMD_SLICE_TMPL_LLAMA4,
+) = (0, 1, 2, 3)
+
+# Define whisper filters structure
+class whisper_filters(Structure):
+    _fields_ = [
+        ("n_mel", c_int),
+    ]
+
+# Define mtmd_context structure
+class mtmd_context(Structure):
+    _fields_ = [
+        ("ctx_v", c_void_p),  # clip_ctx*
+        ("ctx_a", c_void_p),  # clip_ctx*
+        ("text_model", c_void_p),  # const llama_model*
+        ("image_embd_v", POINTER(c_float)),  # std::vector<float>
+        ("print_timings", c_bool),
+        ("n_threads", c_int),
+        ("media_marker", c_char_p),  # std::string
+        ("n_embd_text", c_int),
+        ("img_beg", c_char_p),  # std::string
+        ("img_end", c_char_p),  # std::string
+        ("aud_beg", c_char_p),  # std::string
+        ("aud_end", c_char_p),  # std::string
+        ("slice_tmpl", c_int),  # mtmd_slice_tmpl
+        ("tok_ov_img_start", llama_cpp.llama_token),
+        ("tok_ov_img_end", llama_cpp.llama_token),
+        ("tok_slices_start", llama_cpp.llama_token),
+        ("tok_slices_end", llama_cpp.llama_token),
+        ("tok_sli_img_start", llama_cpp.llama_token),
+        ("tok_sli_img_end", llama_cpp.llama_token),
+        ("tok_sli_img_mid", llama_cpp.llama_token),
+        ("tok_row_end", llama_cpp.llama_token),
+        ("tok_row_end_trail", c_bool),
+        ("ov_img_first", c_bool),
+        ("use_mrope", c_bool),
+        ("w_filters", whisper_filters),
+    ]
+
+# Define bitmap structure
+class mtmd_bitmap(Structure):
+    _fields_ = [
+        ("nx", c_uint32),
+        ("ny", c_uint32),
+        ("data", POINTER(c_uint8)),  # Vector represented as pointer
+        ("id", c_char_p),
+        ("is_audio", c_bool),
+    ]
+
+# Define image tokens structure
+class mtmd_image_tokens(Structure):
+    _fields_ = [
+        ("nx", c_uint32),
+        ("ny", c_uint32),
+        ("use_mrope_pos", c_bool),
+        ("batch_f32", c_void_p),  # clip_image_f32_batch
+        ("id", c_char_p),
+    ]

-# Specify the base name of the shared library to load
+# Define audio tokens structure
+class mtmd_audio_tokens(Structure):
+    _fields_ = [
+        ("n_tokens", c_uint32),
+        ("batch_f32", c_void_p),  # clip_image_f32_batch
+        ("id", c_char_p),
+    ]
+
+# Define input chunk structure
+class mtmd_input_chunk(Structure):
+    _fields_ = [
+        ("type", mtmd_input_chunk_type),
+        ("tokens_text", POINTER(llama_cpp.llama_token)),  # Vector represented as pointer
+        ("tokens_image", c_void_p),  # mtmd_image_tokens_ptr
+        ("tokens_audio", c_void_p),  # mtmd_audio_tokens_ptr
+    ]
+
+# Define input chunks structure
+class mtmd_input_chunks(Structure):
+    _fields_ = [
+        ("entries", POINTER(mtmd_input_chunk)),  # Vector represented as pointer
+    ]
+
+# Define context pointer type
+mtmd_context_p = NewType("mtmd_context_p", int)
+mtmd_context_p_ctypes = c_void_p
+
+# Define bitmap pointer type
+mtmd_bitmap_p = NewType("mtmd_bitmap_p", int)
+mtmd_bitmap_p_ctypes = c_void_p
+
+# Define input chunks pointer type
+mtmd_input_chunks_p = NewType("mtmd_input_chunks_p", int)
+mtmd_input_chunks_p_ctypes = c_void_p
+
+# Define input chunk pointer type
+mtmd_input_chunk_p = NewType("mtmd_input_chunk_p", int)
+mtmd_input_chunk_p_ctypes = c_void_p
+
+# Define image tokens pointer type
+mtmd_image_tokens_p = NewType("mtmd_image_tokens_p", int)
+mtmd_image_tokens_p_ctypes = c_void_p
+
+# Define audio tokens pointer type
+mtmd_audio_tokens_p = NewType("mtmd_audio_tokens_p", int)
+mtmd_audio_tokens_p_ctypes = c_void_p
+
+# Load the library
 _libmtmd_base_name = "mtmd"
 _libmtmd_override_path = os.environ.get("MTMD_CPP_LIB")
 _libmtmd_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _libmtmd_override_path is None else pathlib.Path()

 # Load the library
 _libmtmd = load_shared_library(_libmtmd_base_name, _libmtmd_base_path)
-
 ctypes_function = ctypes_function_for_shared_library(_libmtmd)

 ################################################
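
The new eval_image in llama_chat_format.py calls a BitmapManager helper (BitmapManager(), add_from_memory(), c_ptr(), entries, clear()) whose definition is not part of the hunks shown here. A minimal sketch of the interface those calls imply, assuming mtmd_helper_bitmap_init_from_buf and mtmd_bitmap_free bindings exist elsewhere in mtmd_cpp.py (both names follow the upstream mtmd C API, but their presence and exact signatures here are an assumption):

    class BitmapManager:
        # Owns mtmd_bitmap* handles and frees them in one place (sketch).
        def __init__(self):
            self.entries = []  # mtmd_bitmap* handles, in insertion order

        def add_from_memory(self, ctx, data: bytes) -> bool:
            # Decode raw image bytes into an mtmd bitmap owned by this manager.
            buf = (c_uint8 * len(data)).from_buffer_copy(data)
            bitmap = mtmd_helper_bitmap_init_from_buf(ctx, buf, len(data))
            if not bitmap:
                return False
            self.entries.append(bitmap)
            return True

        def c_ptr(self):
            # mtmd_tokenize expects a C array of mtmd_bitmap* plus its length.
            return (mtmd_bitmap_p_ctypes * len(self.entries))(*self.entries)

        def clear(self):
            for bitmap in self.entries:
                mtmd_bitmap_free(bitmap)
            self.entries = []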

tests/monalisa.jpg

529 KB

tests/test_llava.py

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+import multiprocessing
+import ctypes
+
+from huggingface_hub import hf_hub_download
+
+import pytest
+
+import llama_cpp
+
+@pytest.fixture
+def mmproj_model_path():
+    repo_id = "second-state/Llava-v1.5-7B-GGUF"
+    filename = "llava-v1.5-7b-mmproj-model-f16.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+@pytest.fixture
+def llava_cpp_model_path():
+    repo_id = "second-state/Llava-v1.5-7B-GGUF"
+    filename = "llava-v1.5-7b-Q8_0.gguf"
+    model_path = hf_hub_download(repo_id, filename)
+    return model_path
+
+def test_real_llava(llava_cpp_model_path, mmproj_model_path):
+    print("initializing model")
+    model = llama_cpp.Llama(
+        llava_cpp_model_path,
+        n_ctx=2048,
+        n_batch=512,
+        n_threads=multiprocessing.cpu_count(),
+        n_threads_batch=multiprocessing.cpu_count(),
+        logits_all=False,
+        verbose=False,
+    )
+
+    # Initialize the LLaVA chat handler
+    from llama_cpp.llama_chat_format import Llava15ChatHandler
+    print("initializing chat handler")
+    chat_handler = Llava15ChatHandler(clip_model_path=mmproj_model_path, llama_model=model)
+
+    # Create a chat message with the image
+    print("creating chat message")
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": "./tests/monalisa.jpg"
+                },
+                {
+                    "type": "text",
+                    "text": "Do you know who drew this painting?"
+                }
+            ]
+        }
+    ]
+
+    # Generate response
+    print("generating response")
+    response = chat_handler(
+        llama=model,
+        messages=messages,
+        max_tokens=200,
+        temperature=0.2,
+        top_p=0.95,
+        stream=False
+    )
+
+    print("response", response)
+    # Check that we got a response
+    assert response is not None
+    assert "choices" in response
+    assert len(response["choices"]) > 0
+    assert "message" in response["choices"][0]
+    assert "content" in response["choices"][0]["message"]
+
+    # The response should mention Leonardo da Vinci
+    content = response["choices"][0]["message"]["content"].lower()
+    assert "leonardo" in content and "vinci" in content  # Artist name should be in response
