@@ -276,10 +276,11 @@ def __init__(
276276 self .verbose = verbose
277277 self ._exit_stack = ExitStack ()
278278
279- ctx = llama_cpp .llama_new_context_with_model (self .model .model , self .params )
279+ ctx = llama_cpp .llama_init_from_model (self .model .model , self .params )
280280
281281 if ctx is None :
282- raise ValueError ("Failed to create llama_context" )
282+ llama_cpp .llama_model_free (self .model .model )
283+ raise ValueError ("Failed to create context with model" )
283284
284285 self .ctx = ctx
285286
@@ -445,15 +446,39 @@ def decode(self, batch: LlamaBatch):
445446 def set_n_threads (self , n_threads : int , n_threads_batch : int ):
446447 llama_cpp .llama_set_n_threads (self .ctx , n_threads , n_threads_batch )
447448
449+ def n_threads (self ) -> int :
450+ return llama_cpp .llama_n_threads (self .ctx )
451+
452+ def n_threads_batch (self ) -> int :
453+ return llama_cpp .llama_n_threads_batch (self .ctx )
454+
455+ def set_causal_attn (self , causal_attn : bool ):
456+ llama_cpp .llama_set_causal_attn (self .ctx , causal_attn )
457+
458+ def set_warmup (self , warmup : bool ):
459+ llama_cpp .llama_set_warmup (self .ctx , warmup )
460+
461+ def synchronize (self ):
462+ llama_cpp .llama_synchronize (self .ctx )
463+
448464 def get_logits (self ):
449465 return llama_cpp .llama_get_logits (self .ctx )
450466
451467 def get_logits_ith (self , i : int ):
452468 return llama_cpp .llama_get_logits_ith (self .ctx , i )
453469
470+ def set_embeddings (self , embeddings : bool ):
471+ llama_cpp .llama_set_embeddings (self .ctx , embeddings )
472+
454473 def get_embeddings (self ):
455474 return llama_cpp .llama_get_embeddings (self .ctx )
456475
476+ def get_embeddings_ith (self , i : int ):
477+ return llama_cpp .llama_get_embeddings_ith (self .ctx , i )
478+
479+ def get_embeddings_seq (self , seq_id : int ):
480+ return llama_cpp .llama_get_embeddings_seq (self .ctx , seq_id )
481+
457482 # Sampling functions
458483
459484 def set_rng_seed (self , seed : int ):
0 commit comments