@@ -141,7 +141,7 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
141141 if _key is None :
142142 raise KeyError ("Key not found" )
143143 value : "LlamaState" = self .cache .pop (_key ) # type: ignore
144- # NOTE: This puts an integer as key in cache, which breaks,
144+ # NOTE: This puts an integer as key in cache, which breaks,
145145 # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
146146 # self.cache.push(_key, side="front") # type: ignore
147147 return value
@@ -166,17 +166,15 @@ def __setitem__(self, key: Sequence[int], value: "LlamaState"):
166166class LlamaState :
167167 def __init__ (
168168 self ,
169- eval_tokens : Deque [int ],
170- eval_logits : Deque [List [float ]],
171169 input_ids : npt .NDArray [np .intc ],
172170 scores : npt .NDArray [np .single ],
171+ n_tokens : int ,
173172 llama_state : bytes ,
174173 llama_state_size : int ,
175174 ):
176- self .eval_tokens = eval_tokens
177- self .eval_logits = eval_logits
178175 self .input_ids = input_ids
179176 self .scores = scores
177+ self .n_tokens = n_tokens
180178 self .llama_state = llama_state
181179 self .llama_state_size = llama_state_size
182180
@@ -267,8 +265,6 @@ def __init__(
267265
268266 self .last_n_tokens_size = last_n_tokens_size
269267 self .n_batch = min (n_ctx , n_batch )
270- self .eval_tokens : Deque [int ] = deque (maxlen = n_ctx )
271- self .eval_logits : Deque [List [float ]] = deque (maxlen = n_ctx if logits_all else 1 )
272268
273269 self .cache : Optional [BaseLlamaCache ] = None
274270
@@ -329,8 +325,30 @@ def __init__(
329325 self ._token_nl = Llama .token_nl ()
330326 self ._token_eos = Llama .token_eos ()
331327
332- self ._input_ids = np .array ([], dtype = np .intc )
333- self ._scores : npt .NDArray [np .single ] = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
328+ self .n_tokens = 0
329+ self .input_ids : npt .NDArray [np .intc ] = np .ndarray ((n_ctx ,), dtype = np .intc )
330+ self .scores : npt .NDArray [np .single ] = np .ndarray (
331+ (n_ctx , self ._n_vocab ), dtype = np .single
332+ )
333+
334+ @property
335+ def _input_ids (self ) -> npt .NDArray [np .intc ]:
336+ return self .input_ids [: self .n_tokens ]
337+
338+ @property
339+ def _scores (self ) -> npt .NDArray [np .single ]:
340+ return self .scores [: self .n_tokens , :]
341+
@property
def eval_tokens(self) -> Deque[int]:
    """Return the evaluated tokens as a deque bounded by ``_n_ctx``.

    A fresh deque is built from the first ``n_tokens`` entries of
    ``self.input_ids`` on every access, so mutating the returned deque
    does not change the underlying buffer.
    """
    evaluated = self.input_ids[: self.n_tokens]
    history: Deque[int] = deque(maxlen=self._n_ctx)
    history.extend(evaluated.tolist())
    return history
345+
@property
def eval_logits(self) -> Deque[List[float]]:
    """Return the stored logits rows as a deque of Python float lists.

    The deque is built from the first ``n_tokens`` rows of ``self.scores``.
    When ``self.params.logits_all`` is true it is bounded by ``_n_ctx`` and
    holds every stored row; otherwise it is bounded to one element and
    holds only the most recent row (or nothing when ``n_tokens`` is 0).

    A fresh deque is built on every access, so mutating the returned deque
    does not change the underlying buffer.
    """
    evaluated = self.scores[: self.n_tokens, :]
    if self.params.logits_all:
        maxlen = self._n_ctx
    else:
        maxlen = 1
        # A maxlen=1 deque keeps only the final row, so convert just that
        # row instead of materialising the whole matrix as Python lists.
        evaluated = evaluated[-1:]
    return deque(evaluated.tolist(), maxlen=maxlen)
334352
335353 def tokenize (self , text : bytes , add_bos : bool = True ) -> List [int ]:
336354 """Tokenize a string.
@@ -397,10 +415,7 @@ def set_cache(self, cache: Optional[BaseLlamaCache]):
397415
398416 def reset (self ):
399417 """Reset the model state."""
400- self .eval_tokens .clear ()
401- self .eval_logits .clear ()
402- self ._input_ids = np .array ([], dtype = np .intc )
403- self ._scores = np .ndarray ((0 , self ._n_vocab ), dtype = np .single )
418+ self .n_tokens = 0
404419
405420 def eval (self , tokens : Sequence [int ]):
406421 """Evaluate a list of tokens.
@@ -410,7 +425,6 @@ def eval(self, tokens: Sequence[int]):
410425 """
411426 assert self .ctx is not None
412427 n_ctx = self ._n_ctx
413- scores : List [npt .NDArray [np .single ]] = []
414428 for i in range (0 , len (tokens ), self .n_batch ):
415429 batch = tokens [i : min (len (tokens ), i + self .n_batch )]
416430 n_past = min (n_ctx - len (batch ), len (self ._input_ids ))
@@ -425,19 +439,16 @@ def eval(self, tokens: Sequence[int]):
425439 if return_code != 0 :
426440 raise RuntimeError (f"llama_eval returned { return_code } " )
427441 # Save tokens
428- self .eval_tokens .extend (batch )
429- self ._input_ids : npt .NDArray [np .intc ] = np .concatenate (
430- (self ._input_ids , np .array (batch , dtype = np .intc )), axis = 0
431- )
442+ self .input_ids [self .n_tokens : self .n_tokens + n_tokens ] = batch
432443 # Save logits
433444 rows = n_tokens if self .params .logits_all else 1
434445 n_vocab = self ._n_vocab
435446 cols = n_vocab
436447 logits_view = llama_cpp .llama_get_logits (self .ctx )
437448 logits = [logits_view [i * cols : (i + 1 ) * cols ] for i in range (rows )]
438- self .eval_logits . extend ( logits )
439- scores . append ( np . array ( logits , dtype = np . single ))
440- self ._scores = np . concatenate ( scores )
449+ self .scores [ self . n_tokens : self . n_tokens + n_tokens , :] = logits
450+ # Update n_tokens
451+ self .n_tokens += n_tokens
441452
442453 def _sample (
443454 self ,
@@ -457,8 +468,7 @@ def _sample(
457468 logits_processor : Optional [LogitsProcessorList ] = None ,
458469 ):
459470 assert self .ctx is not None
460- assert len (self .eval_logits ) > 0
461- assert self ._scores .shape [0 ] > 0
471+ assert self .n_tokens > 0
462472 n_vocab = self ._n_vocab
463473 n_ctx = self ._n_ctx
464474 top_k = llama_cpp .c_int (n_vocab ) if top_k .value <= 0 else top_k
@@ -475,7 +485,6 @@ def _sample(
475485 dtype = np .single ,
476486 )
477487 self ._scores [- 1 , :] = logits
478- self .eval_logits [- 1 ] = logits .tolist ()
479488
480489 nl_logit = logits [self ._token_nl ]
481490 candidates = self ._candidates
@@ -672,14 +681,7 @@ def generate(
672681 print ("Llama.generate: prefix-match hit" , file = sys .stderr )
673682 reset = False
674683 tokens = tokens [longest_prefix :]
675- self ._input_ids = self ._input_ids [:longest_prefix ]
676- self ._scores = self ._scores [:longest_prefix , :]
677- for _ in range (len (self .eval_tokens ) - longest_prefix ):
678- self .eval_tokens .pop ()
679- try :
680- self .eval_logits .pop ()
681- except IndexError :
682- pass
684+ self .n_tokens = longest_prefix
683685
684686 if reset :
685687 self .reset ()
@@ -819,7 +821,9 @@ def _create_completion(
819821 llama_cpp .llama_reset_timings (self .ctx )
820822
821823 if len (prompt_tokens ) > self ._n_ctx :
822- raise ValueError (f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { self ._n_ctx } " )
824+ raise ValueError (
825+ f"Requested tokens ({ len (prompt_tokens )} ) exceed context window of { self ._n_ctx } "
826+ )
823827
824828 # Truncate max_tokens if requested tokens would exceed the context window
825829 max_tokens = (
@@ -1513,22 +1517,20 @@ def save_state(self) -> LlamaState:
15131517 file = sys .stderr ,
15141518 )
15151519 return LlamaState (
1516- eval_tokens = self .eval_tokens .copy (),
1517- eval_logits = self .eval_logits .copy (),
1518- scores = self ._scores .copy (),
1519- input_ids = self ._input_ids .copy (),
1520+ scores = self .scores .copy (),
1521+ input_ids = self .input_ids .copy (),
1522+ n_tokens = self .n_tokens ,
15201523 llama_state = bytes (llama_state_compact ),
15211524 llama_state_size = n_bytes ,
15221525 )
15231526
15241527 def load_state (self , state : LlamaState ) -> None :
15251528 assert self .ctx is not None
1526- self .eval_tokens = state .eval_tokens .copy ()
1527- self .eval_logits = state .eval_logits .copy ()
1528- self ._scores = state .scores .copy ()
1529- self ._input_ids = state .input_ids .copy ()
1529+ self .scores = state .scores .copy ()
1530+ self .input_ids = state .input_ids .copy ()
1531+ self .n_tokens = state .n_tokens
15301532 state_size = state .llama_state_size
1531- LLamaStateArrayType = ( llama_cpp .c_uint8 * state_size )
1533+ LLamaStateArrayType = llama_cpp .c_uint8 * state_size
15321534 llama_state = LLamaStateArrayType .from_buffer_copy (state .llama_state )
15331535
15341536 if llama_cpp .llama_set_state_data (self .ctx , llama_state ) != state_size :
0 commit comments