 import time
 import math
 import multiprocessing
-from typing import List, Optional, Union, Generator, Sequence, Iterator
+from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque
 from collections import deque

 from . import llama_cpp
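The new `Deque` import exists because built-in generics such as `deque[...]` are only subscriptable at runtime from Python 3.9 onward (PEP 585), and annotations on attribute targets like `self.x: T = ...` are evaluated when `__init__` runs. A minimal sketch of the failure mode (the `Buffer` class is hypothetical):

```python
from collections import deque
from typing import Deque


class Buffer:
    def __init__(self, size: int):
        # typing.Deque works on older interpreters as well:
        self.items: Deque[int] = deque(maxlen=size)
        # The pre-change spelling raises TypeError on Python < 3.9,
        # because the annotation on an attribute target is evaluated:
        # self.items: deque[int] = deque(maxlen=size)
```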
@@ -20,6 +20,18 @@ class LlamaCache:
     pass


+class LlamaState:
+    def __init__(
+        self,
+        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_logits: Deque[List[float]],
+        llama_state,
+    ):
+        self.eval_tokens = eval_tokens
+        self.eval_logits = eval_logits
+        self.llama_state = llama_state
+
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""

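`LlamaState` itself is a passive value object; it never calls into llama.cpp. A purely illustrative sketch of what a snapshot holds (the values are hypothetical; a real snapshot comes from `Llama.save_state()` below):

```python
from collections import deque

snapshot = LlamaState(
    eval_tokens=deque([1, 2, 3], maxlen=512),
    eval_logits=deque([[0.0, 0.1]], maxlen=512),
    llama_state=None,  # normally a ctypes (c_uint8 * state_size) array
)
assert list(snapshot.eval_tokens) == [1, 2, 3]
```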
@@ -85,8 +97,8 @@ def __init__(

         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
-        self.eval_logits: deque[List[float]] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx)

         ### HACK: This is a hack to work around the fact that the llama.cpp API does not yet support
         ###       saving and restoring state, this allows us to continue a completion if the last
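Both deques are bounded by `maxlen=n_ctx`, so they mirror the model's context window: once full, appending a new entry silently drops the oldest one. A quick standard-library sketch of that eviction behavior:

```python
from collections import deque

window = deque(maxlen=3)
for token in [10, 11, 12, 13]:
    window.append(token)

# The first token fell off the left edge once capacity was reached.
assert list(window) == [11, 12, 13]
```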
@@ -204,7 +216,10 @@ def eval(self, tokens: Sequence[llama_cpp.llama_token]):
             cols = int(n_vocab)
             rows = n_tokens
             logits_view = llama_cpp.llama_get_logits(self.ctx)
-            logits = [[logits_view[i * cols + j] for j in range(cols)] for i in range(rows)]
+            logits = [
+                [logits_view[i * cols + j] for j in range(cols)]
+                for i in range(rows)
+            ]
             self.eval_logits.extend(logits)

     def sample(
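`llama_get_logits` returns one flat buffer of `rows * cols` floats in row-major order, so element `(i, j)` lives at index `i * cols + j`; the reformatted comprehension simply reshapes that buffer into one list per token. A small sketch of the same indexing, with a plain list standing in for the ctypes view:

```python
rows, cols = 2, 3          # e.g. n_tokens x n_vocab
flat = [0, 1, 2, 3, 4, 5]  # stand-in for the flat float buffer

logits = [
    [flat[i * cols + j] for j in range(cols)]
    for i in range(rows)
]
assert logits == [[0, 1, 2], [3, 4, 5]]
```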
@@ -828,6 +843,26 @@ def __setstate__(self, state):
             verbose=state["verbose"],
         )

+    def save_state(self) -> LlamaState:
+        assert self.ctx is not None
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        llama_state = (llama_cpp.c_uint8 * int(state_size))()
+        if llama_cpp.llama_copy_state_data(self.ctx, llama_state) != state_size:
+            raise RuntimeError("Failed to copy llama state data")
+        return LlamaState(
+            eval_tokens=self.eval_tokens.copy(),
+            eval_logits=self.eval_logits.copy(),
+            llama_state=llama_state,
+        )
+
+    def load_state(self, state: LlamaState) -> None:
+        assert self.ctx is not None
+        self.eval_tokens = state.eval_tokens.copy()
+        self.eval_logits = state.eval_logits.copy()
+        state_size = llama_cpp.llama_get_state_size(self.ctx)
+        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
+            raise RuntimeError("Failed to set llama state data")
+
     @staticmethod
     def token_eos() -> llama_cpp.llama_token:
         """Return the end-of-sequence token."""
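Together, the two new methods enable cheap rewinding: snapshot the context once, keep evaluating, then restore. A minimal usage sketch (the model path and prompt are hypothetical):

```python
llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical path

llm.eval(llm.tokenize(b"The quick brown fox"))
snapshot = llm.save_state()             # deque copies + opaque state blob

llm.eval(llm.tokenize(b" jumps over"))  # advance the context further

llm.load_state(snapshot)                # rewind to the snapshot
assert list(llm.eval_tokens) == list(snapshot.eval_tokens)
```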