@@ -545,13 +545,12 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
545545class _LlamaTokenDataArray :
546546 def __init__ (self , * , n_vocab : int ):
547547 self .n_vocab = n_vocab
548- self .candidates_data = np .array (
549- [] ,
548+ self .candidates_data = np .recarray (
549+ ( self . n_vocab ,) ,
550550 dtype = np .dtype (
551551 [("id" , np .intc ), ("logit" , np .single ), ("p" , np .single )], align = True
552552 ),
553553 )
554- self .candidates_data .resize (3 , self .n_vocab , refcheck = False )
555554 self .candidates = llama_cpp .llama_token_data_array (
556555 data = self .candidates_data .ctypes .data_as (llama_cpp .llama_token_data_p ),
557556 size = self .n_vocab ,
@@ -561,14 +560,11 @@ def __init__(self, *, n_vocab: int):
561560 self .default_candidates_data_p = np .zeros (self .n_vocab , dtype = np .single )
562561
563562 def copy_logits (self , logits : npt .NDArray [np .single ]):
564- self .candidates_data ["id" ][:] = self .default_candidates_data_id
565- self .candidates_data ["logit" ][:] = logits
566- self .candidates_data ["p" ][:] = self .default_candidates_data_p
567- self .candidates .data = self .candidates_data .ctypes .data_as (
568- llama_cpp .llama_token_data_p
569- )
570- self .candidates .sorted = ctypes .c_bool (False )
571- self .candidates .size = ctypes .c_size_t (self .n_vocab )
563+ self .candidates_data .id [:] = self .default_candidates_data_id
564+ self .candidates_data .logit [:] = logits
565+ self .candidates_data .p [:] = self .default_candidates_data_p
566+ self .candidates .sorted = False
567+ self .candidates .size = self .n_vocab
572568
573569
574570# Python wrappers over common/common
@@ -759,14 +755,14 @@ def sample(
759755 self .params .penalty_present ,
760756 )
761757 if not self .params .penalize_nl :
762- token_data_array .candidates_data [ " logit" ] [nl_token ] = nl_logit
758+ token_data_array .candidates_data . logit [nl_token ] = nl_logit
763759
764760 if self .grammar is not None :
765761 ctx_main .sample_grammar (token_data_array , self .grammar )
766762
767763 if self .params .temp < 0 :
768764 ctx_main .sample_softmax (token_data_array )
769- id = token_data_array .candidates_data [ "id" ] [0 ]
765+ id = token_data_array .candidates_data . id [0 ]
770766 elif self .params .temp == 0 :
771767 id = ctx_main .sample_token_greedy (token_data_array )
772768 else :
0 commit comments