@@ -15,9 +15,7 @@ class LlamaCache:
1515 """Cache for a llama.cpp model."""
1616
1717 def __init__ (self , capacity_bytes : int = (2 << 30 )):
18- self .cache_state : OrderedDict [
19- Tuple [llama_cpp .llama_token , ...], "LlamaState"
20- ] = OrderedDict ()
18+ self .cache_state : OrderedDict [Tuple [int , ...], "LlamaState" ] = OrderedDict ()
2119 self .capacity_bytes = capacity_bytes
2220
2321 @property
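
Throughout this commit the public API trades the ctypes alias `llama_cpp.llama_token` for plain `int`, and the first payoff is right here: cache keys. A minimal sketch of why, assuming the low-level bindings alias `llama_token` to `ctypes.c_int` (which the wrapped call sites later in the diff suggest):

    import ctypes

    llama_token = ctypes.c_int  # assumed alias, mirroring the low-level bindings

    a, b = llama_token(42), llama_token(42)
    print(a == b)    # False: ctypes instances compare by identity, not value
    print(42 == 42)  # True: plain ints compare and hash by value

    # so Tuple[int, ...] behaves as a proper OrderedDict key for the cache
    cache = {}
    cache[(1, 2, 3)] = "state"
    print((1, 2, 3) in cache)  # True
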
@@ -26,8 +24,8 @@ def cache_size(self):
 
     def _find_longest_prefix_key(
         self,
-        key: Tuple[llama_cpp.llama_token, ...],
-    ) -> Optional[Tuple[llama_cpp.llama_token, ...]]:
+        key: Tuple[int, ...],
+    ) -> Optional[Tuple[int, ...]]:
         min_len = 0
         min_key = None
         keys = (
@@ -39,7 +37,7 @@ def _find_longest_prefix_key(
                 min_key = k
         return min_key
 
-    def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
+    def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
         if _key is None:
@@ -48,10 +46,10 @@ def __getitem__(self, key: Sequence[llama_cpp.llama_token]) -> "LlamaState":
         self.cache_state.move_to_end(_key)
         return value
 
-    def __contains__(self, key: Sequence[llama_cpp.llama_token]) -> bool:
+    def __contains__(self, key: Sequence[int]) -> bool:
         return self._find_longest_prefix_key(tuple(key)) is not None
 
-    def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState"):
+    def __setitem__(self, key: Sequence[int], value: "LlamaState"):
         key = tuple(key)
         if key in self.cache_state:
             del self.cache_state[key]
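
Note that `__contains__` and `__getitem__` both go through `_find_longest_prefix_key`, so lookups match the longest cached *prefix* rather than requiring an exact key. A usage sketch with made-up token ids (the `LlamaState` construction is elided):

    cache = LlamaCache()
    cache[(1, 2, 3)] = state  # state: a previously captured LlamaState

    # (1, 2, 3, 4, 5) was never stored, but it shares the prefix (1, 2, 3)
    print((1, 2, 3, 4, 5) in cache)  # True
    reused = cache[(1, 2, 3, 4, 5)]  # returns the state saved under (1, 2, 3)
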
@@ -63,7 +61,7 @@ def __setitem__(self, key: Sequence[llama_cpp.llama_token], value: "LlamaState")
 class LlamaState:
     def __init__(
         self,
-        eval_tokens: Deque[llama_cpp.llama_token],
+        eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
@@ -141,7 +139,7 @@ def __init__(
 
         self.last_n_tokens_size = last_n_tokens_size
         self.n_batch = min(n_ctx, n_batch)
-        self.eval_tokens: Deque[llama_cpp.llama_token] = deque(maxlen=n_ctx)
+        self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
         self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
 
         self.cache: Optional[LlamaCache] = None
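
Because both deques are created with `maxlen`, the evaluation history is bounded by the context window: once `n_ctx` token ids have been appended, the oldest entries fall off the left end silently. For example:

    from collections import deque
    from typing import Deque

    eval_tokens: Deque[int] = deque(maxlen=4)  # stand-in for n_ctx = 4
    eval_tokens.extend([1, 2, 3, 4, 5])
    print(list(eval_tokens))  # [2, 3, 4, 5] -- the oldest token was dropped
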
@@ -176,9 +174,7 @@ def __init__(
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
-        self, text: bytes, add_bos: bool = True
-    ) -> List[llama_cpp.llama_token]:
+    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
         Args:
@@ -197,7 +193,7 @@ def tokenize(
             self.ctx,
             text,
             tokens,
-            n_ctx,
+            llama_cpp.c_int(n_ctx),
             llama_cpp.c_bool(add_bos),
         )
         if int(n_tokens) < 0:
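
With the Python-facing signature now typed as `int`, the explicit `llama_cpp.c_int(n_ctx)` wrap keeps the conversion at the ctypes boundary visible. Strictly speaking, ctypes converts plain Python ints to `c_int` parameters on its own; the explicit form mainly documents intent. A standalone sketch of the pattern (using `abs` from libc, so Unix-only):

    import ctypes

    libc = ctypes.CDLL(None)  # load the running process's libc (Unix)
    libc.abs.argtypes = [ctypes.c_int]
    libc.abs.restype = ctypes.c_int

    print(libc.abs(-3))                # 3: plain int converted automatically
    print(libc.abs(ctypes.c_int(-3)))  # 3: explicit wrap, same result
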
@@ -216,7 +212,7 @@ def tokenize(
             )
         return list(tokens[:n_tokens])
 
-    def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
+    def detokenize(self, tokens: List[int]) -> bytes:
         """Detokenize a list of tokens.
 
         Args:
@@ -228,7 +224,9 @@ def detokenize(self, tokens: List[llama_cpp.llama_token]) -> bytes:
         assert self.ctx is not None
         output = b""
         for token in tokens:
-            output += llama_cpp.llama_token_to_str(self.ctx, token)
+            output += llama_cpp.llama_token_to_str(
+                self.ctx, llama_cpp.llama_token(token)
+            )
         return output
 
     def set_cache(self, cache: Optional[LlamaCache]):
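
`detokenize` is the one place a plain `int` crosses back into ctypes land, rewrapped as `llama_cpp.llama_token` per call. End to end, the two methods now roundtrip through ordinary ints (model path hypothetical; the exact bytes depend on BOS and leading-space handling in the tokenizer):

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical path
    ids = llm.tokenize(b"Hello, world!")  # List[int] under the new signature
    print(ids)
    print(llm.detokenize(ids))  # roughly b'Hello, world!' back again
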
@@ -244,7 +242,7 @@ def reset(self):
         self.eval_tokens.clear()
         self.eval_logits.clear()
 
-    def eval(self, tokens: Sequence[llama_cpp.llama_token]):
+    def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
 
         Args:
@@ -458,7 +456,7 @@ def sample(
 
     def generate(
         self,
-        tokens: Sequence[llama_cpp.llama_token],
+        tokens: Sequence[int],
         top_k: int = 40,
         top_p: float = 0.95,
         temp: float = 0.80,
@@ -470,9 +468,7 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-    ) -> Generator[
-        llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
-    ]:
+    ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
         Examples:
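
The narrowed annotation also reads better at call sites: consumers iterate over plain ints and compare directly against `Llama.token_eos()`. A usage sketch in the style of the method's own docstring example (hypothetical model path):

    llm = Llama(model_path="./models/7B/ggml-model.bin")  # hypothetical path
    tokens = llm.tokenize(b"Hello, world!")
    for token in llm.generate(
        tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1
    ):
        if token == llm.token_eos():
            break
        print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")
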
@@ -617,14 +613,14 @@ def _create_completion(
         assert self.ctx is not None
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[llama_cpp.llama_token] = []
+        completion_tokens: List[int] = []
         # Add blank space to start of prompt to match OG llama tokenizer
-        prompt_tokens: List[llama_cpp.llama_token] = self.tokenize(
-            b" " + prompt.encode("utf-8")
-        )
+        prompt_tokens: List[int] = self.tokenize(b" " + prompt.encode("utf-8"))
         text: bytes = b""
         returned_tokens: int = 0
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         model_name: str = model if model is not None else self.model_path
 
         if self.verbose:
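
The `stop = (...)` change is pure reformatting for line length; the conditional expression still normalizes `stop` from `None`, a single string, or a list into a list. Standalone:

    def normalize_stop(stop):
        # the same expression as in _create_completion, shown in isolation
        return stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []

    print(normalize_stop(None))          # []
    print(normalize_stop("###"))         # ['###']
    print(normalize_stop(["\n", "Q:"]))  # ['\n', 'Q:']
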
@@ -724,7 +720,9 @@ def _create_completion(
                 for token in remaining_tokens:
                     token_end_position += len(self.detokenize([token]))
                     # Check if stop sequence is in the token
-                    if token_end_position >= (remaining_length - first_stop_position - 1):
+                    if token_end_position >= (
+                        remaining_length - first_stop_position - 1
+                    ):
                         break
                     logprobs_or_none: Optional[CompletionLogprobs] = None
                     if logprobs is not None:
@@ -744,7 +742,7 @@ def _create_completion(
                             )
                         )
                         top_logprob = {
-                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            self.detokenize([i]).decode(
                                 "utf-8", errors="ignore"
                             ): logprob
                             for logprob, i in sorted_logprobs[:logprobs]
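
Dropping the `llama_cpp.llama_token(i)` wrap works because `detokenize` now accepts `List[int]` directly and rewraps internally. The comprehension maps each of the top-`logprobs` candidate ids to its decoded string; a self-contained sketch with made-up values, where `fake_detokenize` stands in for `self.detokenize`:

    # hypothetical (logprob, token_id) pairs, sorted descending by logprob
    sorted_logprobs = [(-0.1, 11), (-1.7, 42), (-3.2, 7)]
    vocab = {11: b" the", 42: b" a", 7: b" an"}

    def fake_detokenize(ids):  # stand-in for self.detokenize
        return b"".join(vocab[i] for i in ids)

    logprobs = 2
    top_logprob = {
        fake_detokenize([i]).decode("utf-8", errors="ignore"): logprob
        for logprob, i in sorted_logprobs[:logprobs]
    }
    print(top_logprob)  # {' the': -0.1, ' a': -1.7}
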
@@ -822,9 +820,7 @@ def _create_completion(
                     )
                 )
                 top_logprob = {
-                    self.detokenize([llama_cpp.llama_token(i)]).decode(
-                        "utf-8", errors="ignore"
-                    ): logprob
+                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: current_logprobs[int(token)]})
@@ -924,9 +920,7 @@ def _create_completion(
             )
             token_logprobs.append(sorted_logprobs[int(token)][0])
             top_logprob: Optional[Dict[str, float]] = {
-                self.detokenize([llama_cpp.llama_token(i)]).decode(
-                    "utf-8", errors="ignore"
-                ): logprob
+                self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
                 for logprob, i in sorted_logprobs[:logprobs]
             }
             top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1188,7 +1182,9 @@ def create_chat_completion(
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        stop = stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        stop = (
+            stop if isinstance(stop, list) else [stop] if isinstance(stop, str) else []
+        )
         chat_history = "".join(
             f'### {"Human" if message["role"] == "user" else "Assistant"}:{message["content"]}'
             for message in messages
@@ -1296,17 +1292,17 @@ def load_state(self, state: LlamaState) -> None:
             raise RuntimeError("Failed to set llama state data")
 
     @staticmethod
-    def token_eos() -> llama_cpp.llama_token:
+    def token_eos() -> int:
         """Return the end-of-sequence token."""
         return llama_cpp.llama_token_eos()
 
     @staticmethod
-    def token_bos() -> llama_cpp.llama_token:
+    def token_bos() -> int:
         """Return the beginning-of-sequence token."""
         return llama_cpp.llama_token_bos()
 
     @staticmethod
-    def token_nl() -> llama_cpp.llama_token:
+    def token_nl() -> int:
         """Return the newline token."""
         return llama_cpp.llama_token_nl()
 
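
These three are static because the underlying llama.cpp functions of this vintage take no context argument; the special token ids are fixed by the vocabulary, and they now come back as plain ints:

    print(Llama.token_bos())  # typically 1 for the LLaMA vocabulary
    print(Llama.token_eos())  # typically 2
    print(Llama.token_nl())   # typically 13 (the id of "\n")
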
@@ -1317,9 +1313,7 @@ def logits_to_logprobs(logits: List[float]) -> List[float]:
         return [math.log(x / sum_exps) for x in exps]
 
     @staticmethod
-    def longest_token_prefix(
-        a: Sequence[llama_cpp.llama_token], b: Sequence[llama_cpp.llama_token]
-    ):
+    def longest_token_prefix(a: Sequence[int], b: Sequence[int]):
         longest_prefix = 0
         for _a, _b in zip(a, b):
             if _a == _b:
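
`longest_token_prefix` is what `LlamaCache._find_longest_prefix_key` ranks candidate keys with, and with plain ints the `_a == _b` check is a straightforward value comparison. Expected behaviour:

    print(Llama.longest_token_prefix([1, 2, 3, 4], [1, 2, 9]))  # 2
    print(Llama.longest_token_prefix([], [1, 2, 3]))            # 0
    print(Llama.longest_token_prefix([5, 6], [5, 6, 7, 8]))     # 2
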