@@ -2713,22 +2713,27 @@ class Llava15ChatHandler:
         "{% endif %}"
     )
 
-    def __init__(self, clip_model_path: str, llama_model: llama.Llama, verbose: bool = True):
-        import llama_cpp.mtmd_cpp as mtmd_cpp
-
+    def __init__(self, clip_model_path: str, llama_model: Optional[llama.Llama] = None, verbose: bool = True):
         self.clip_model_path = clip_model_path
         self.verbose = verbose
-
-        self._mtmd_cpp = mtmd_cpp
+        self._mtmd_cpp = None
         self._exit_stack = ExitStack()
         self._bitmap_manager = None
+        self.clip_ctx = None
+        self._params = None
 
         if not os.path.exists(clip_model_path):
             raise ValueError(f"Clip model path does not exist: {clip_model_path}")
 
-        # We'll initialize the clip context later when we have the llama model
-        self.clip_ctx = None
-        self._params = None
+        # Initialize MTMD context if model is provided
+        if llama_model is not None:
+            self.initialize_mtmd_context(llama_model)
+
+    def initialize_mtmd_context(self, llama_model: llama.Llama):
+        """Initialize the MTMD context with a llama model."""
+        import llama_cpp.mtmd_cpp as mtmd_cpp
+        self._mtmd_cpp = mtmd_cpp
+
         with suppress_stdout_stderr(disable=self.verbose):
             params = self._mtmd_cpp.mtmd_context_params_default()
             params.use_gpu = True  # TODO: Make configurable
@@ -2748,10 +2753,22 @@ def mtmd_free():
 
         self._exit_stack.callback(mtmd_free)
 
+    def __call__(self, *args, **kwargs):
+        if self.clip_ctx is None:
+            # Initialize MTMD context with the llama model from the first argument
+            if len(args) > 0 and isinstance(args[0], llama.Llama):
+                self.initialize_mtmd_context(args[0])
+            else:
+                raise ValueError("MTMD context not initialized. Please call initialize_mtmd_context with a llama model first.")
+        return super().__call__(*args, **kwargs)
+
     def load_image(self, image_url: str) -> bytes:
         return self._load_image(image_url)
 
     def eval_image(self, llama: llama.Llama, image_url: str):
+        if self.clip_ctx is None:
+            self.initialize_mtmd_context(llama)
+
         image_bytes = self.load_image(image_url)
 
         # Create bitmap manager if not exists
@@ -3481,61 +3498,6 @@ def split_text_on_image_urls(text: str, image_urls: List[str]):
             remaining = ""
         return split_text
 
-    def eval_image(self, llama: llama.Llama, image_url: str):
-        image_bytes = self.load_image(image_url)
-
-        # Create bitmap manager if not exists
-        if self._bitmap_manager is None:
-            self._bitmap_manager = self._mtmd_cpp.BitmapManager()
-
-        # Create bitmap from bytes
-        if not self._bitmap_manager.add_from_memory(self.clip_ctx, image_bytes):
-            raise ValueError("Failed to create bitmap from image bytes")
-
-        # Create input chunks for the bitmap
-        chunks = self._mtmd_cpp.mtmd_input_chunks_init()
-        if chunks is None:
-            raise ValueError("Failed to create input chunks")
-
-        # Create input text with media marker
-        # Get media marker from context params
-        params = self._mtmd_cpp.mtmd_context_params_default()
-        text = self._mtmd_cpp.mtmd_input_text()
-        text.text = params.media_marker if params.media_marker else self._mtmd_cpp.mtmd_default_marker()
-        text.add_special = False
-        text.parse_special = True
-
-        # Tokenize with bitmap
-        if self._mtmd_cpp.mtmd_tokenize(self.clip_ctx, chunks, text, self._bitmap_manager.c_ptr(), len(self._bitmap_manager.entries)) != 0:
-            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-            raise ValueError("Failed to tokenize image")
-
-        # Get new n_past after evaluation
-        n_past = ctypes.c_int(llama.n_tokens)
-        n_past_p = ctypes.pointer(n_past)
-
-        # Evaluate chunks
-        if self._mtmd_cpp.mtmd_helper_eval_chunks(
-            self.clip_ctx,
-            llama.ctx,
-            chunks,
-            llama.n_tokens,
-            0,  # seq_id
-            llama.n_batch,
-            True,  # logits_last
-            n_past_p
-        ) != 0:
-            self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-            raise ValueError("Failed to evaluate chunks")
-
-        # Update n_tokens
-        llama.input_ids[llama.n_tokens : n_past.value] = -1
-        llama.n_tokens = n_past.value
-
-        # Cleanup
-        self._mtmd_cpp.mtmd_input_chunks_free(chunks)
-        self._bitmap_manager.clear()
-
 
 def _accumulate_chunks(
     chunks_iterator: Iterator[llama_types.CreateCompletionStreamResponse],
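For context, a minimal usage sketch of the deferred initialization this change enables. The `.gguf` paths are hypothetical placeholders; `Llama` and `Llava15ChatHandler` are the existing llama-cpp-python classes touched by this diff:

```python
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

# llama_model is now Optional, so the handler can be constructed before
# the model exists; the MTMD context is created later.
chat_handler = Llava15ChatHandler(clip_model_path="mmproj-model-f16.gguf")  # hypothetical path

llm = Llama(
    model_path="llava-v1.5-7b.Q4_K_M.gguf",  # hypothetical path
    chat_handler=chat_handler,
    n_ctx=4096,
)

# Bind the model explicitly; otherwise eval_image() (and __call__, when it
# receives the model as its first positional argument) initializes lazily.
chat_handler.initialize_mtmd_context(llm)
```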