
Commit 8839d47: migrate clip to mtmd
Parent: 44e2893

2 files changed: +14, -58 lines

llama_cpp/llama.py

Lines changed: 4 additions & 0 deletions
@@ -478,6 +478,10 @@ def free_lora_adapter():
                 f"Using fallback chat format: {self.chat_format}", file=sys.stderr
             )
 
+        if self.chat_handler is not None:
+            if isinstance(self.chat_handler, llama_chat_format.Llava15ChatHandler):
+                self.chat_handler.initialize_mtmd_context(self)
+
         self._sampler = None
 
     @property
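
This hook means a Llava15ChatHandler no longer has to be constructed with a live llama model; Llama finishes the wiring itself once the model is loaded. A minimal usage sketch of the resulting flow, with hypothetical file paths:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    # llama_model can now be omitted: Llama.__init__ calls
    # chat_handler.initialize_mtmd_context(self) after loading the model.
    chat_handler = Llava15ChatHandler(clip_model_path="mmproj-model.gguf")
    llm = Llama(
        model_path="llava-v1.5-7b.Q4_K_M.gguf",
        chat_handler=chat_handler,
        n_ctx=4096,
    )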

llama_cpp/llama_chat_format.py

Lines changed: 10 additions & 58 deletions
@@ -2722,9 +2722,7 @@ class Llava15ChatHandler:
         "{% endif %}"
     )
 
-    def __init__(self, clip_model_path: str, llama_model: llama.Llama, verbose: bool = True):
-        import llama_cpp.mtmd_cpp as mtmd_cpp
-
+    def __init__(self, clip_model_path: str, llama_model: Optional[llama.Llama] = None, verbose: bool = True):
         self.clip_model_path = clip_model_path
         self.verbose = verbose
         self._mtmd_cpp = mtmd_cpp
@@ -2769,6 +2767,15 @@ def mtmd_free():
 
         self._exit_stack.callback(mtmd_free)
 
+    def __call__(self, *args, **kwargs):
+        if self.clip_ctx is None:
+            # Initialize MTMD context with the llama model from the first argument
+            if len(args) > 0 and isinstance(args[0], llama.Llama):
+                self.initialize_mtmd_context(args[0])
+            else:
+                raise ValueError("MTMD context not initialized. Please call initialize_mtmd_context with a llama model first.")
+        return super().__call__(*args, **kwargs)
+
     def load_image(self, image_url: str) -> bytes:
         return self._load_image(image_url)
 
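
The __call__ override adds a lazy fallback: if the context was never initialized eagerly (for example, the handler was attached after construction), the first call can still set it up, provided a Llama instance arrives as the first positional argument. A sketch of that path, with hypothetical names:

    handler = Llava15ChatHandler(clip_model_path="mmproj-model.gguf")
    # handler.clip_ctx is still None here; passing the model positionally
    # triggers initialize_mtmd_context before delegating to super().__call__().
    response = handler(llm, messages=[{"role": "user", "content": "Hello"}])

Note that Llama.create_chat_completion invokes chat handlers with the model passed as the keyword argument llama=self, which this positional check does not cover; the eager hook added in llama.py above is what handles that path.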

@@ -3638,61 +3645,6 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
                 remaining = ""
         return split_text
 
-    def eval_image(self, llama: llama.Llama, image_url: str):
-        image_bytes = self.load_image(image_url)
-
-        # Create bitmap manager if not exists
-        if self._bitmap_manager is None:
-            self._bitmap_manager = self._mtmd_cpp.BitmapManager()
-
-        # Create bitmap from bytes
-        if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes):
-            raise ValueError("Failed to create bitmap from image bytes")
-
-        # Create input chunks for the bitmap
-        chunks = self._mtmd_cpp.mtmd_input_chunks_init()
-        if chunks is None:
-            raise ValueError("Failed to create input chunks")
-
-        # Create input text with media marker
-        # Get media marker from context params
-        params = self._mtmd_cpp.mtmd_context_params_default()
-        text = self._mtmd_cpp.mtmd_input_text()
-        text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker()
-        text.add_special = False
-        text.parse_special = True
-
-        # Tokenize with bitmap
-        if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0:
-            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-            raise ValueError("Failed to tokenize image")
-
-        # Get new n_past after evaluation
-        n_past = ctypes.c_int(llama.n_tokens)
-        n_past_p = ctypes.pointer(n_past)
-
-        # Evaluate chunks
-        if self._mtmd_cpp.mtmd_helper_eval_chunks(
-            self.clip_ctx,
-            llama.ctx,
-            chunks,
-            llama.n_tokens,
-            0,  # seq_id
-            llama.n_batch,
-            True,  # logits_last
-            n_past_p
-        ) != 0:
-            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-            raise ValueError("Failed to evaluate chunks")
-
-        # Update n_tokens
-        llama.input_ids[llama.n_tokens : n_past.value] = -1
-        llama.n_tokens = n_past.value
-
-        # Cleanup
-        self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-        self._bitmap_manager.clear()
-
 
 def _accumulate_chunks(
     chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
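
With eval_image removed, image decoding and evaluation are handled through the mtmd context rather than this hand-rolled bitmap-and-chunks loop. An end-to-end sketch of the multimodal chat path after the migration, reusing the llm set up earlier (URL and prompt are placeholders):

    result = llm.create_chat_completion(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
    )
    print(result["choices"][0]["message"]["content"])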
