Skip to content

Commit 646beb7

Browse files
committed
migrate clip to mtmd
1 parent c66b6e9 commit 646beb7

File tree

2 files changed

+29
-63
lines changed

2 files changed

+29
-63
lines changed

llama_cpp/llama.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,10 @@ def free_lora_adapter():
481481
f"Using fallback chat format: {self.chat_format}", file=sys.stderr
482482
)
483483

484+
if self.chat_handler is not None:
485+
if isinstance(self.chat_handler, llama_chat_format.Llava15ChatHandler):
486+
self.chat_handler.initialize_mtmd_context(self)
487+
484488
self._sampler = None
485489

486490
@property

llama_cpp/llama_chat_format.py

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2713,22 +2713,27 @@ class Llava15ChatHandler:
27132713
"{% endif %}"
27142714
)
27152715

2716-
def __init__(self, clip_model_path: str, llama_model: llama.Llama, verbose: bool = True):
2717-
import llama_cpp.mtmd_cpp as mtmd_cpp
2718-
2716+
def __init__(self, clip_model_path: str, llama_model: Optional[llama.Llama] = None, verbose: bool = True):
27192717
self.clip_model_path = clip_model_path
27202718
self.verbose = verbose
2721-
2722-
self._mtmd_cpp = mtmd_cpp
2719+
self._mtmd_cpp = None
27232720
self._exit_stack = ExitStack()
27242721
self._bitmap_manager = None
2722+
self.clip_ctx = None
2723+
self._params = None
27252724

27262725
if not os.path.exists(clip_model_path):
27272726
raise ValueError(f"Clip model path does not exist: {clip_model_path}")
27282727

2729-
# We'll initialize the clip context later when we have the llama model
2730-
self.clip_ctx = None
2731-
self._params = None
2728+
# Initialize MTMD context if model is provided
2729+
if llama_model is not None:
2730+
self.initialize_mtmd_context(llama_model)
2731+
2732+
def initialize_mtmd_context(self, llama_model: llama.Llama):
2733+
"""Initialize the MTMD context with a llama model."""
2734+
import llama_cpp.mtmd_cpp as mtmd_cpp
2735+
self._mtmd_cpp = mtmd_cpp
2736+
27322737
with suppress_stdout_stderr(disable=self.verbose):
27332738
params = self._mtmd_cpp.mtmd_context_params_default()
27342739
params.use_gpu = True # TODO: Make configurable
@@ -2748,10 +2753,22 @@ def mtmd_free():
27482753

27492754
self._exit_stack.callback(mtmd_free)
27502755

2756+
def __call__(self, *args, **kwargs):
2757+
if self.clip_ctx is None:
2758+
# Initialize MTMD context with the llama model from the first argument
2759+
if len(args) > 0 and isinstance(args[0], llama.Llama):
2760+
self.initialize_mtmd_context(args[0])
2761+
else:
2762+
raise ValueError("MTMD context not initialized. Please call initialize_mtmd_context with a llama model first.")
2763+
return super().__call__(*args, **kwargs)
2764+
27512765
def load_image(self, image_url: str) -> bytes:
27522766
return self._load_image(image_url)
27532767

27542768
def eval_image(self, llama: llama.Llama, image_url: str):
2769+
if self.clip_ctx is None:
2770+
self.initialize_mtmd_context(llama)
2771+
27552772
image_bytes = self.load_image(image_url)
27562773

27572774
# Create bitmap manager if not exists
@@ -3481,61 +3498,6 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
34813498
remaining = ""
34823499
return split_text
34833500

3484-
def eval_image(self, llama: llama.Llama, image_url: str):
3485-
image_bytes = self.load_image(image_url)
3486-
3487-
# Create bitmap manager if not exists
3488-
if self._bitmap_manager is None:
3489-
self._bitmap_manager = self._mtmd_cpp.BitmapManager()
3490-
3491-
# Create bitmap from bytes
3492-
if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes):
3493-
raise ValueError("Failed to create bitmap from image bytes")
3494-
3495-
# Create input chunks for the bitmap
3496-
chunks = self._mtmd_cpp.mtmd_input_chunks_init()
3497-
if chunks is None:
3498-
raise ValueError("Failed to create input chunks")
3499-
3500-
# Create input text with media marker
3501-
# Get media marker from context params
3502-
params = self._mtmd_cpp.mtmd_context_params_default()
3503-
text = self._mtmd_cpp.mtmd_input_text()
3504-
text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker()
3505-
text.add_special = False
3506-
text.parse_special = True
3507-
3508-
# Tokenize with bitmap
3509-
if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0:
3510-
self._mtmd_cpp.mtmd_input_chunks_free(chunks)
3511-
raise ValueError("Failed to tokenize image")
3512-
3513-
# Get new n_past after evaluation
3514-
n_past = ctypes.c_int(llama.n_tokens)
3515-
n_past_p = ctypes.pointer(n_past)
3516-
3517-
# Evaluate chunks
3518-
if self._mtmd_cpp.mtmd_helper_eval_chunks(
3519-
self.clip_ctx,
3520-
llama.ctx,
3521-
chunks,
3522-
llama.n_tokens,
3523-
0, # seq_id
3524-
llama.n_batch,
3525-
True, # logits_last
3526-
n_past_p
3527-
) != 0:
3528-
self._mtmd_cpp.mtmd_input_chunks_free(chunks)
3529-
raise ValueError("Failed to evaluate chunks")
3530-
3531-
# Update n_tokens
3532-
llama.input_ids[llama.n_tokens : n_past.value] = -1
3533-
llama.n_tokens = n_past.value
3534-
3535-
# Cleanup
3536-
self._mtmd_cpp.mtmd_input_chunks_free(chunks)
3537-
self._bitmap_manager.clear()
3538-
35393501

35403502
def _accumulate_chunks(
35413503
chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],

0 commit comments

Comments
 (0)