 from . import llama_cpp
 from .llama_types import *

+import numpy as np
+import numpy.typing as npt
+

 class LlamaCache:
     """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
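
The two new `LlamaState` fields snapshot the model's numpy-side history. A minimal sketch of their expected shapes, inferred from the dtypes in this hunk (the sizes here are illustrative):

    import numpy as np
    import numpy.typing as npt

    n_tokens, n_vocab = 3, 32000  # illustrative sizes
    # one int32 token id per evaluated position
    input_ids: npt.NDArray[np.intc] = np.zeros(n_tokens, dtype=np.intc)
    # one float32 logits row per evaluated position
    scores: npt.NDArray[np.single] = np.zeros((n_tokens, n_vocab), dtype=np.single)
    assert scores.shape == (input_ids.shape[0], n_vocab)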
@@ -207,27 +214,27 @@ def __init__(

         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [],
+            dtype=np.dtype(
+                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+            ),
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()

+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.

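The structured dtype mirrors llama.cpp's `llama_token_data` struct, which is what makes handing the array's raw buffer to C via `ctypes.data_as(llama_token_data_p)` legitimate; note the diff over-allocates the buffer to shape `(3, self._n_vocab)` while `candidates.size` only ever exposes the first `n_vocab` records. A standalone sketch of the layout match, assuming the C struct is `{int32 id; float logit; float p;}` (mirrored locally here rather than imported from `llama_cpp`):

    import ctypes
    import numpy as np

    class llama_token_data(ctypes.Structure):  # assumed local mirror of the C struct
        _fields_ = [
            ("id", ctypes.c_int),
            ("logit", ctypes.c_float),
            ("p", ctypes.c_float),
        ]

    dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )

    # Identical size/alignment is what makes the pointer cast safe.
    assert dtype.itemsize == ctypes.sizeof(llama_token_data)

    n_vocab = 8
    candidates_data = np.zeros(n_vocab, dtype=dtype)
    candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
    candidates_data["logit"] = np.linspace(-1.0, 1.0, n_vocab, dtype=np.single)

    # The first record, viewed through ctypes, matches the first numpy row.
    first = llama_token_data.from_address(candidates_data.ctypes.data)
    assert first.id == int(candidates_data["id"][0])
    assert abs(first.logit - float(candidates_data["logit"][0])) < 1e-6
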
@@ -295,6 +302,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -306,7 +315,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -319,13 +328,19 @@ def eval(self, tokens: Sequence[int]):
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )

     def _sample(
         self,
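
`eval()` now keeps the numpy arrays in lockstep with the existing deques: one token id and (with `logits_all`) one logits row appended per evaluated token. A toy sketch of that bookkeeping, with random rows standing in for `llama_get_logits` output:

    import numpy as np

    n_vocab = 4
    _input_ids = np.array([], dtype=np.intc)
    _scores = np.ndarray((0, n_vocab), dtype=np.single)

    for batch in ([1, 2], [3]):
        # stand-in for the logits rows this batch would produce
        logits = np.random.randn(len(batch), n_vocab).astype(np.single)
        _input_ids = np.concatenate((_input_ids, np.array(batch, dtype=np.intc)), axis=0)
        _scores = np.concatenate((_scores, logits), axis=0)

    assert _scores.shape == (len(_input_ids), n_vocab)  # one row per token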
@@ -346,6 +361,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -354,18 +370,23 @@ def _sample(
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]

         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()

         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
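
The removed loop walked every vocab entry through ctypes, three field writes at a time; the replacement fills whole columns of the structured array in three vectorized assignments. A toy equivalence check (standalone dtype, no `llama_cpp` needed):

    import numpy as np

    n_vocab = 5
    logits = np.linspace(-2.0, 2.0, n_vocab).astype(np.single)
    dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )

    vectorized = np.zeros(n_vocab, dtype=dtype)
    vectorized["id"] = np.arange(n_vocab, dtype=np.intc)
    vectorized["logit"] = logits  # "p" stays zero, as in the diff

    looped = np.zeros(n_vocab, dtype=dtype)
    for i, logit in enumerate(logits):  # what the removed loop did, record by record
        looped[i] = (i, logit, 0.0)

    assert (vectorized == looped).all()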
@@ -483,8 +504,8 @@ def sample(
483504 """
484505 assert self .ctx is not None
485506 last_n_tokens_data = [llama_cpp .llama_token (0 )] * max (
486- 0 , self .last_n_tokens_size - len (self .eval_tokens )
487- ) + list ( self .eval_tokens ) [- self .last_n_tokens_size :]
507+ 0 , self .last_n_tokens_size - len (self ._input_ids )
508+ ) + self ._input_ids [- self .last_n_tokens_size :]. tolist ()
488509 return self ._sample (
489510 last_n_tokens_data = (llama_cpp .llama_token * self .last_n_tokens_size )(
490511 * last_n_tokens_data
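
One subtlety `sample()` preserves across the rewrite: when fewer than `last_n_tokens_size` tokens have been evaluated, the repetition-penalty window is left-padded with token 0. A toy check of that padding arithmetic:

    import numpy as np

    last_n_tokens_size = 5
    _input_ids = np.array([7, 8], dtype=np.intc)  # only two tokens evaluated so far

    last_n_tokens_data = [0] * max(
        0, last_n_tokens_size - len(_input_ids)
    ) + _input_ids[-last_n_tokens_size:].tolist()

    assert last_n_tokens_data == [0, 0, 0, 7, 8]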
@@ -542,9 +563,9 @@ def generate(
         """
         assert self.ctx is not None

-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -554,6 +575,8 @@ def generate(
                 print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                 try:
@@ -580,7 +603,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
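
The prefix-match path now truncates `_input_ids` and `_scores` alongside the deques so everything stays aligned with the reused KV cache. A toy run of the matching loop (the last incoming token is excluded so at least one token is always re-evaluated):

    import numpy as np

    cached = np.array([1, 2, 3, 4], dtype=np.intc)  # tokens already evaluated
    incoming = [1, 2, 9, 9]                         # new prompt

    longest_prefix = 0
    for a, b in zip(cached, incoming[:-1]):
        if a == b:
            longest_prefix += 1
        else:
            break

    assert longest_prefix == 2
    assert incoming[longest_prefix:] == [9, 9]  # only this suffix needs llama_eval
    cached = cached[:longest_prefix]            # arrays truncated to match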
@@ -715,10 +738,10 @@ def _create_completion(
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.eval_tokens, prompt_tokens
+                    cache_item.input_ids.tolist(), prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self.eval_tokens, prompt_tokens
+                    self._input_ids.tolist(), prompt_tokens
                 )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
@@ -807,7 +830,7 @@ def _create_completion(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -856,7 +879,7 @@ def _create_completion(
                 break

         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -886,7 +909,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -988,8 +1011,7 @@ def _create_completion(
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
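
`Llama.logits_to_logprobs`, used throughout these hunks, is a log-softmax over the vocabulary; the rows it consumes now come straight out of `_scores`. A numerically stable numpy sketch of the transform (the library's own helper may differ in implementation details):

    import numpy as np

    def logits_to_logprobs(logits):
        a = np.asarray(logits, dtype=np.single)
        a = a - a.max()  # shift for numerical stability
        return (a - np.log(np.exp(a).sum())).tolist()

    row = [1.0, 2.0, 3.0]
    assert abs(sum(np.exp(logits_to_logprobs(row))) - 1.0) < 1e-5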
@@ -1373,6 +1395,8 @@ def save_state(self) -> LlamaState:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1381,6 +1405,8 @@ def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
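
`save_state`/`load_state` copy the arrays in both directions, so a snapshot is never aliased to the live buffers. A toy illustration of why those `.copy()` calls matter:

    import numpy as np

    scores = np.zeros((2, 4), dtype=np.single)
    snapshot = scores.copy()       # as in save_state()
    scores[0, 0] = 1.0             # later eval() writes...
    assert snapshot[0, 0] == 0.0   # ...do not leak into the snapshot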