@@ -198,6 +198,8 @@ def __init__(
198198 sorted = sorted ,
199199 )
200200 self ._candidates = candidates
201+ self ._token_nl = Llama .token_nl ()
202+ self ._token_eos = Llama .token_eos ()
201203
202204 def tokenize (self , text : bytes , add_bos : bool = True ) -> List [int ]:
203205 """Tokenize a string.
@@ -327,7 +329,7 @@ def _sample(
327329 else last_n_tokens_size
328330 )
329331 logits = self .eval_logits [- 1 ]
330- nl_logit = logits [Llama . token_nl () ]
332+ nl_logit = logits [self . _token_nl ]
331333 candidates = self ._candidates
332334 for i , logit in enumerate (logits ):
333335 candidates .data [i ].id = llama_cpp .llama_token (i )
@@ -351,7 +353,7 @@ def _sample(
351353 alpha_presence = presence_penalty ,
352354 )
353355 if not penalize_nl :
354- candidates .data [Llama . token_nl () ].logit = llama_cpp .c_float (nl_logit )
356+ candidates .data [self . _token_nl ].logit = llama_cpp .c_float (nl_logit )
355357 if temp .value == 0.0 :
356358 return llama_cpp .llama_sample_token_greedy (
357359 ctx = self .ctx ,
@@ -688,7 +690,7 @@ def _create_completion(
688690 presence_penalty = presence_penalty ,
689691 repeat_penalty = repeat_penalty ,
690692 ):
691- if token == Llama . token_eos () :
693+ if token == self . _token_eos :
692694 text = self .detokenize (completion_tokens )
693695 finish_reason = "stop"
694696 break
0 commit comments