 
 
 class LlamaCache:
-    """Cache for a llama.cpp model.
+    """Cache for a llama.cpp model."""
 
-    NOTE: This implementation currently only tells the Llama class to avoid reprocessing bytes and continue from the last
-    completion. It does not actually cache the results."""
+    def __init__(self):
+        self.cache_state: Dict[Sequence[llama_cpp.llama_token], "LlamaState"] = dict()
 
-    pass
+    def __getitem__(
+        self, key: Sequence[llama_cpp.llama_token]
+    ) -> Optional["LlamaState"]:
+        return self.cache_state.get(tuple(key), None)
+
+    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+        return tuple(key) in self.cache_state
+
+    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+        self.cache_state = dict()  # NOTE: Currently limit to one cache entry.
+        self.cache_state[tuple(key)] = value
 
 
 class LlamaState:
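Note: the new `LlamaCache` behaves like a one-entry mapping from token sequences (made hashable via `tuple`) to saved `LlamaState` objects. A rough standalone sketch of that behaviour, using plain ints in place of `llama_cpp.llama_token` values and strings in place of `LlamaState`:

```python
# Illustrative sketch only: ints stand in for llama_cpp.llama_token values,
# strings stand in for LlamaState objects.
cache_state = {}

def cache_set(key, value):
    cache_state.clear()                  # keep at most one entry, as in LlamaCache.__setitem__
    cache_state[tuple(key)] = value

def cache_get(key):
    return cache_state.get(tuple(key))   # tuple() makes the token sequence hashable

cache_set([1, 2, 3], "state-A")
assert cache_get([1, 2, 3]) == "state-A"
cache_set([4, 5], "state-B")             # evicts the previous entry
assert cache_get([1, 2, 3]) is None
```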
@@ -100,13 +110,7 @@ def __init__(
         self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)
 
-        ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
-        ### saving and restoring state, this allows us to continue a completion if the last
-        ### completion_bytes is a prefix to the prompt passed in. However this is actually incorrect
-        ### because it does not take into account stop tokens which have been processed by the model.
-        self._completion_bytes: List[bytes] = []
-        self._cache: Optional[LlamaCache] = None
-        ###
+        self.cache: Optional[LlamaCache] = None
 
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
 
@@ -182,7 +186,7 @@ def set_cache(self, cache: Optional[LlamaCache]):
         Args:
             cache: The cache to set.
         """
-        self._cache = cache
+        self.cache = cache
 
     def reset(self):
         """Reset the model state."""
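Note: with the renamed public `cache` attribute, enabling caching is just constructing a `LlamaCache` and handing it to `set_cache`. A minimal usage sketch; the model path is a placeholder, and it assumes `LlamaCache` is exported at the package top level alongside `Llama`:

```python
from llama_cpp import Llama, LlamaCache

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder path
llm.set_cache(LlamaCache())

# If the second prompt tokenizes to a key already in the cache, the saved
# model state is loaded instead of re-evaluating the prompt.
prompt = "Q: Name the planets in the solar system. A:"
_ = llm(prompt, max_tokens=1)
out = llm(prompt, max_tokens=32)
```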
@@ -287,18 +291,17 @@ def generate(
             The generated tokens.
         """
         assert self.ctx is not None
-        ### HACK
+
         if (
             reset
-            and self._cache
             and len(self.eval_tokens) > 0
             and self.eval_tokens == tokens[: len(self.eval_tokens)]
         ):
             if self.verbose:
                 print("generate cache hit", file=sys.stderr)
             reset = False
             tokens = tokens[len(self.eval_tokens) :]
-        ###
+
         if reset:
             self.reset()
         while True:
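Note: the prefix-reuse shortcut in `generate` no longer depends on a cache being set: whenever the already-evaluated tokens form a prefix of the incoming token sequence, only the remaining tail is evaluated. A conceptual sketch of that check (a hypothetical helper, not code from the library, which compares `self.eval_tokens` in place):

```python
from collections import deque
from typing import List

def reusable_prefix_len(eval_tokens: "deque[int]", tokens: List[int]) -> int:
    """How many leading tokens can be skipped because they were already evaluated."""
    prev = list(eval_tokens)
    if prev and prev == tokens[: len(prev)]:
        return len(prev)
    return 0

evaluated = deque([101, 7, 42], maxlen=512)
assert reusable_prefix_len(evaluated, [101, 7, 42, 9, 13]) == 3  # only [9, 13] left to eval
assert reusable_prefix_len(evaluated, [101, 8, 42]) == 0         # mismatch: full reset
```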
@@ -415,20 +418,10 @@ def _create_completion(
                 "logprobs is not supported for models created with logits_all=False"
             )
 
-        ### HACK
-        reset: bool = True
-        _prompt: bytes = prompt.encode("utf-8")
-        _completion: bytes = b"".join(self._completion_bytes)
-        if len(_completion) and self._cache and _prompt.startswith(_completion):
+        if self.cache and prompt_tokens in self.cache:
             if self.verbose:
-                print("completion cache hit", file=sys.stderr)
-            reset = False
-            _prompt = _prompt[len(_completion) :]
-            prompt_tokens = self.tokenize(b" " + _prompt)
-            self._completion_bytes.append(_prompt)
-        else:
-            self._completion_bytes = [prompt.encode("utf-8")]
-        ###
+                print("cache hit", file=sys.stderr)
+            self.load_state(self.cache[prompt_tokens])
 
         finish_reason = "length"
         for token in self.generate(
@@ -437,12 +430,16 @@ def _create_completion(
             top_p=top_p,
             temp=temperature,
             repeat_penalty=repeat_penalty,
-            reset=reset,
         ):
             if token == llama_cpp.llama_token_eos():
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
+
+            if self.cache and len(completion_tokens) == 0:
+                if prompt_tokens not in self.cache:
+                    self.cache[prompt_tokens] = self.save_state()
+
             completion_tokens.append(token)
 
             all_text = self.detokenize(completion_tokens)
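Note: taken together, the `_create_completion` changes replace the byte-prefix hack with token-keyed state snapshots: on a cache hit the saved state is loaded before generation, and otherwise the state is saved exactly once, after the prompt has been evaluated but before the first completion token is recorded. A condensed paraphrase of that control flow, using hypothetical `model`/`cache` stand-ins rather than the actual implementation:

```python
def create_completion_sketch(model, cache, prompt_tokens, max_tokens=16):
    # Cache hit: restore the model state saved for this exact prompt.
    if cache is not None and prompt_tokens in cache:
        model.load_state(cache[prompt_tokens])

    completion_tokens = []
    for token in model.generate(prompt_tokens):
        # Save state once, before any completion token is recorded, so a later
        # call with the same prompt can skip prompt evaluation entirely.
        if cache is not None and len(completion_tokens) == 0:
            if prompt_tokens not in cache:
                cache[prompt_tokens] = model.save_state()
        completion_tokens.append(token)
        if len(completion_tokens) >= max_tokens:
            break
    return model.detokenize(completion_tokens)
```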
@@ -467,9 +464,6 @@ def _create_completion(
                             break
                 text = all_text[: len(all_text) - longest]
                 returned_characters += len(text[start:])
-                ### HACK
-                self._completion_bytes.append(text[start:])
-                ###
                 yield {
                     "id": completion_id,
                     "object": "text_completion",
@@ -491,9 +485,6 @@ def _create_completion(
                 break
 
         if stream:
-            ### HACK
-            self._completion_bytes.append(text[returned_characters:])
-            ###
             yield {
                 "id": completion_id,
                 "object": "text_completion",
@@ -510,9 +501,6 @@ def _create_completion(
             }
             return
 
-        ### HACK
-        self._completion_bytes.append(text)
-        ###
         text_str = text.decode("utf-8")
 
         if echo: