@@ -11,6 +11,15 @@
 from .llama_types import *
 
 
+class LlamaCache:
+    """Cache for a llama.cpp model.
+
+    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
+    completion. It does not actually cache the results."""
+
+    pass
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
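At this point `LlamaCache` is only a sentinel: handing one to `set_cache` (added below) is what enables the continuation behavior; the object itself stores nothing. A minimal usage sketch, assuming package-level `Llama`/`LlamaCache` exports and a hypothetical model path:

```python
from llama_cpp import Llama, LlamaCache

# Hypothetical model path, for illustration only.
llama = Llama(model_path="./models/7B/ggml-model.bin", verbose=True)
llama.set_cache(LlamaCache())  # sentinel object: enables continuation, caches no results

prompt = "Q: Name the planets in the solar system. A:"
first = llama(prompt, max_tokens=48)

# Re-sending the previous prompt plus its completion as the prefix of the next
# prompt lets the model continue without re-evaluating those bytes
# (the "completion cache hit" path below).
second = llama(
    prompt + first["choices"][0]["text"] + " Q: Which of them is largest? A:",
    max_tokens=48,
)
```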
@@ -82,6 +91,14 @@ def __init__(
         self.n_past = 0
         self.all_logits: List[List[float]] = []  # TODO: Use an array instead of a list.
 
+        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
+        ### saving and restoring state; this allows us to continue a completion if the last
+        ### completion_bytes is a prefix to the prompt passed in. However, this is actually incorrect
+        ### because it does not take into account stop tokens which have been processed by the model.
+        self._completion_bytes: List[bytes] = []
+        self._cache: Optional[LlamaCache] = None
+        ###
+
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
         if not os.path.exists(model_path):
@@ -135,6 +152,14 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
             output += llama_cpp.llama_token_to_str(self.ctx, token)
         return output
 
+    def set_cache(self, cache: Optional[LlamaCache]):
+        """Set the cache.
+
+        Args:
+            cache: The cache to set.
+        """
+        self._cache = cache
+
     def reset(self):
         """Reset the model state."""
         self.last_n_tokens_data.extend(
@@ -245,6 +270,17 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
+        ### HACK
+        if (
+            reset
+            and self._cache
+            and len(self.tokens) > 0
+            and self.tokens == tokens[: len(self.tokens)]
+        ):
+            if self.verbose:
+                print("generate cache hit", file=sys.stderr)
+            reset = False
+        ###
         if reset:
             self.reset()
         while True:
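The guard above is a token-level prefix check: the reset is skipped only when every token already evaluated matches the head of the incoming token list. The same condition, shown standalone with hypothetical token IDs:

```python
def can_continue(evaluated: list, incoming: list) -> bool:
    # Mirrors the guard in generate(): something must already be evaluated,
    # and it must exactly match the start of the new request's tokens.
    return len(evaluated) > 0 and evaluated == incoming[: len(evaluated)]

assert can_continue([1, 15, 7], [1, 15, 7, 9, 2])      # prefix matches: keep state
assert not can_continue([1, 15, 7], [1, 15, 8, 9, 2])  # diverged: full reset
assert not can_continue([], [1, 15])                   # nothing evaluated yet: reset
```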
@@ -361,13 +397,29 @@ def _create_completion(
361397 "logprobs is not supported for models created with logits_all=False"
362398 )
363399
400+ ### HACK
401+ reset : bool = True
402+ _prompt : bytes = prompt .encode ("utf-8" )
403+ _completion : bytes = b"" .join (self ._completion_bytes )
404+ if len (_completion ) and self ._cache and _prompt .startswith (_completion ):
405+ if self .verbose :
406+ print ("completion cache hit" , file = sys .stderr )
407+ reset = False
408+ _prompt = _prompt [len (_completion ) :]
409+ prompt_tokens = self .tokenize (b" " + _prompt )
410+ self ._completion_bytes .append (_prompt )
411+ else :
412+ self ._completion_bytes = [prompt .encode ("utf-8" )]
413+ ###
414+
364415 finish_reason = "length"
365416 for token in self .generate (
366417 prompt_tokens ,
367418 top_k = top_k ,
368419 top_p = top_p ,
369420 temp = temperature ,
370421 repeat_penalty = repeat_penalty ,
422+ reset = reset ,
371423 ):
372424 if token == llama_cpp .llama_token_eos ():
373425 text = self .detokenize (completion_tokens )
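The byte-level counterpart works the same way: if everything produced so far (prompt plus generated text) is a prefix of the new prompt, only the unseen tail has to be tokenized and evaluated. A standalone sketch of that bookkeeping, with a hypothetical helper name (it shares the limitation flagged in the HACK comment: stop tokens consumed by the model are not accounted for):

```python
def split_unseen_tail(completion_bytes, prompt):
    """Return (cache_hit, tail): tail is the portion of the prompt that still
    needs tokenizing. Mirrors the prefix logic in _create_completion above."""
    _prompt = prompt.encode("utf-8")
    _completion = b"".join(completion_bytes)
    if len(_completion) and _prompt.startswith(_completion):
        return True, _prompt[len(_completion) :]
    return False, _prompt

hit, tail = split_unseen_tail([b"Q: 2+2? A:", b" 4."], "Q: 2+2? A: 4. Q: 3+3? A:")
assert hit and tail == b" Q: 3+3? A:"
```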
@@ -397,6 +449,9 @@ def _create_completion(
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
+                ### HACK
+                self._completion_bytes.append(text[start:])
+                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -418,6 +473,9 @@ def _create_completion(
                 break
 
         if stream:
+            ### HACK
+            self._completion_bytes.append(text[returned_characters:])
+            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -434,13 +492,16 @@ def _create_completion(
             }
             return
 
-        text = text.decode("utf-8")
+        ### HACK
+        self._completion_bytes.append(text)
+        ###
+        text_str = text.decode("utf-8")
 
         if echo:
-            text = prompt + text
+            text_str = prompt + text_str
 
         if suffix is not None:
-            text = text + suffix
+            text_str = text_str + suffix
 
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
@@ -493,7 +554,7 @@ def _create_completion(
493554 "model" : self .model_path ,
494555 "choices" : [
495556 {
496- "text" : text ,
557+ "text" : text_str ,
497558 "index" : 0 ,
498559 "logprobs" : logprobs_or_none ,
499560 "finish_reason" : finish_reason ,