
Commit 4cf4b15

migrate clip to mtmd
1 parent ffff841 commit 4cf4b15

File tree: 2 files changed, +14 −58 lines


llama_cpp/llama.py

Lines changed: 4 additions & 0 deletions

@@ -478,6 +478,10 @@ def free_lora_adapter():
                 f"Using fallback chat format: {self.chat_format}", file=sys.stderr
             )
 
+        if self.chat_handler is not None:
+            if isinstance(self.chat_handler, llama_chat_format.Llava15ChatHandler):
+                self.chat_handler.initialize_mtmd_context(self)
+
         self._sampler = None
 
     @property
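
With this hook, a Llava15ChatHandler no longer needs a model at construction time: Llama.__init__ now initializes the handler's mtmd context itself. A minimal usage sketch (the model paths are placeholders, not part of this commit):

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# Construct the handler without a model; the mtmd context is created
# later, when Llama.__init__ calls initialize_mtmd_context(self).
chat_handler = Llava15ChatHandler(clip_model_path="./mmproj-f16.gguf")

llm = Llama(
    model_path="./llava-v1.5-7b.Q4_K_M.gguf",
    chat_handler=chat_handler,
    n_ctx=4096,  # leave room for image embeddings in the context window
)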

llama_cpp/llama_chat_format.py

Lines changed: 10 additions & 58 deletions

@@ -2716,9 +2716,7 @@ class Llava15ChatHandler:
         "{% endif %}"
     )
 
-    def __init__(self, clip_model_path: str, llama_model: llama.Llama, verbose: bool = True):
-        import llama_cpp.mtmd_cpp as mtmd_cpp
-
+    def __init__(self, clip_model_path: str, llama_model: Optional[llama.Llama] = None, verbose: bool = True):
         self.clip_model_path = clip_model_path
         self.verbose = verbose
         self._mtmd_cpp = mtmd_cpp
@@ -2763,6 +2761,15 @@ def mtmd_free():
 
         self._exit_stack.callback(mtmd_free)
 
+    def __call__(self, *args, **kwargs):
+        if self.clip_ctx is None:
+            # Initialize MTMD context with the llama model from the first argument
+            if len(args) > 0 and isinstance(args[0], llama.Llama):
+                self.initialize_mtmd_context(args[0])
+            else:
+                raise ValueError("MTMD context not initialized. Please call initialize_mtmd_context with a llama model first.")
+        return super().__call__(*args, **kwargs)
+
     def load_image(self, image_url: str) -> bytes:
         return self._load_image(image_url)
 
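The new __call__ guard also allows lazy initialization, but only when the model arrives as the first positional argument; otherwise it raises the ValueError above. Since the handler is normally invoked with keyword arguments, the reliable paths are the constructor hook in llama.py or an explicit call. A sketch, reusing chat_handler and llm from the previous example:

# For a handler attached after model construction, initialize explicitly:
chat_handler.initialize_mtmd_context(llm)

# Calling the handler with neither an initialized context nor a positional
# Llama argument raises:
#   ValueError: MTMD context not initialized. Please call
#   initialize_mtmd_context with a llama model first.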
@@ -3629,61 +3636,6 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
                 remaining = ""
         return split_text
 
-    def eval_image(self, llama: llama.Llama, image_url: str):
-        image_bytes = self.load_image(image_url)
-
-        # Create bitmap manager if not exists
-        if self._bitmap_manager is None:
-            self._bitmap_manager = self._mtmd_cpp.BitmapManager()
-
-        # Create bitmap from bytes
-        if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes):
-            raise ValueError("Failed to create bitmap from image bytes")
-
-        # Create input chunks for the bitmap
-        chunks = self._mtmd_cpp.mtmd_input_chunks_init()
-        if chunks is None:
-            raise ValueError("Failed to create input chunks")
-
-        # Create input text with media marker
-        # Get media marker from context params
-        params = self._mtmd_cpp.mtmd_context_params_default()
-        text = self._mtmd_cpp.mtmd_input_text()
-        text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker()
-        text.add_special = False
-        text.parse_special = True
-
-        # Tokenize with bitmap
-        if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0:
-            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-            raise ValueError("Failed to tokenize image")
-
-        # Get new n_past after evaluation
-        n_past = ctypes.c_int(llama.n_tokens)
-        n_past_p = ctypes.pointer(n_past)
-
-        # Evaluate chunks
-        if self._mtmd_cpp.mtmd_helper_eval_chunks(
-            self.clip_ctx,
-            llama.ctx,
-            chunks,
-            llama.n_tokens,
-            0,  # seq_id
-            llama.n_batch,
-            True,  # logits_last
-            n_past_p
-        ) != 0:
-            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-            raise ValueError("Failed to evaluate chunks")
-
-        # Update n_tokens
-        llama.input_ids[llama.n_tokens : n_past.value] = -1
-        llama.n_tokens = n_past.value
-
-        # Cleanup
-        self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-        self._bitmap_manager.clear()
-
 
 def _accumulate_chunks(
     chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
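
The deleted eval_image was the old per-image entry point; its body is a useful record of the mtmd flow this migration standardizes on: register the image bytes as a bitmap, tokenize a media marker against it, then evaluate the resulting chunks with the helper. A condensed sketch of that flow, using only the mtmd_cpp calls visible in the removed code (error handling omitted; where this logic now lives is outside this diff):

import ctypes
import llama_cpp.mtmd_cpp as mtmd_cpp

def eval_image_via_mtmd(handler, llm, image_bytes: bytes) -> None:
    # Register the raw image bytes as a bitmap on the mtmd (clip) context.
    bitmaps = mtmd_cpp.BitmapManager()
    bitmaps.add_from_memory(handler.clip_ctx, image_bytes)

    # Tokenize a media-marker prompt against the bitmap into input chunks.
    params = mtmd_cpp.mtmd_context_params_default()
    text = mtmd_cpp.mtmd_input_text()
    text.text = params.media_marker if params.media_marker else mtmd_cpp.mtmd_default_marker()
    text.add_special = False
    text.parse_special = True
    chunks = mtmd_cpp.mtmd_input_chunks_init()
    mtmd_cpp.mtmd_tokenize(handler.clip_ctx, chunks, text,
                           bitmaps.c_ptr(), len(bitmaps.entries))

    # Evaluate the chunks, then advance the model's token bookkeeping.
    n_past = ctypes.c_int(llm.n_tokens)
    mtmd_cpp.mtmd_helper_eval_chunks(
        handler.clip_ctx,
        llm.ctx,
        chunks,
        llm.n_tokens,  # n_past
        0,             # seq_id
        llm.n_batch,
        True,          # logits_last
        ctypes.pointer(n_past),
    )
    llm.input_ids[llm.n_tokens : n_past.value] = -1
    llm.n_tokens = n_past.value

    # Free the chunk list and clear the bitmaps for the next image.
    mtmd_cpp.mtmd_input_chunks_free(chunks)
    bitmaps.clear()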
