Skip to content

Commit 528a518

Browse files
committed
Update the LlamaContext API and release the model pointer when creating a context with the model fails
1 parent 8b62fb0 commit 528a518

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

llama_cpp/_internals.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,11 @@ def __init__(
276276
self.verbose = verbose
277277
self._exit_stack = ExitStack()
278278

279-
ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params)
279+
ctx = llama_cpp.llama_init_from_model(self.model.model, self.params)
280280

281281
if ctx is None:
282-
raise ValueError("Failed to create llama_context")
282+
llama_cpp.llama_model_free(self.model.model)
283+
raise ValueError("Failed to create context with model")
283284

284285
self.ctx = ctx
285286

@@ -445,15 +446,39 @@ def decode(self, batch: LlamaBatch):
445446
def set_n_threads(self, n_threads: int, n_threads_batch: int):
    """Forward the generation and batch thread counts to the native context."""
    llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch)
447448

449+
def n_threads(self) -> int:
    """Return the thread count the native context reports for generation."""
    count = llama_cpp.llama_n_threads(self.ctx)
    return count
451+
452+
def n_threads_batch(self) -> int:
    """Return the thread count the native context reports for batch processing."""
    count = llama_cpp.llama_n_threads_batch(self.ctx)
    return count
454+
455+
def set_causal_attn(self, causal_attn: bool):
    """Toggle causal attention on the native context."""
    llama_cpp.llama_set_causal_attn(self.ctx, causal_attn)
457+
458+
def set_warmup(self, warmup: bool):
    """Toggle warmup mode on the native context."""
    llama_cpp.llama_set_warmup(self.ctx, warmup)
460+
461+
def synchronize(self):
    """Synchronize with the native context (delegates to llama_synchronize)."""
    llama_cpp.llama_synchronize(self.ctx)
463+
448464
def get_logits(self):
    """Return the logits buffer exposed by the native context."""
    logits = llama_cpp.llama_get_logits(self.ctx)
    return logits
450466

451467
def get_logits_ith(self, i: int):
    """Return the logits for the i-th entry from the native context."""
    logits = llama_cpp.llama_get_logits_ith(self.ctx, i)
    return logits
453469

470+
def set_embeddings(self, embeddings: bool):
    """Toggle embeddings output on the native context."""
    llama_cpp.llama_set_embeddings(self.ctx, embeddings)
472+
454473
def get_embeddings(self):
    """Return the embeddings buffer exposed by the native context."""
    emb = llama_cpp.llama_get_embeddings(self.ctx)
    return emb
456475

476+
def get_embeddings_ith(self, i: int):
    """Return the embeddings for the i-th entry from the native context."""
    emb = llama_cpp.llama_get_embeddings_ith(self.ctx, i)
    return emb
478+
479+
def get_embeddings_seq(self, seq_id: int):
    """Return the embeddings for the given sequence id from the native context."""
    emb = llama_cpp.llama_get_embeddings_seq(self.ctx, seq_id)
    return emb
481+
457482
# Sampling functions
458483

459484
def set_rng_seed(self, seed: int):

0 commit comments

Comments
 (0)