@@ -240,11 +240,11 @@ class llama_token_data_array(Structure):
 # typedef struct llama_batch {
 #     int32_t n_tokens;
 
-#     llama_token  * token;
-#     float        * embd;
-#     llama_pos    * pos;
-#     llama_seq_id * seq_id;
-#     int8_t       * logits;
+#     llama_token  *  token;
+#     float        *  embd;
+#     llama_pos    *  pos;
+#     llama_seq_id ** seq_id;
+#     int8_t       *  logits;
 
 
 # // NOTE: helpers for smooth API transition - can be deprecated in the future
@@ -262,7 +262,7 @@ class llama_batch(Structure):
262262 ("token" , POINTER (llama_token )),
263263 ("embd" , c_float_p ),
264264 ("pos" , POINTER (llama_pos )),
265- ("seq_id" , POINTER (llama_seq_id )),
265+ ("seq_id" , POINTER (POINTER ( llama_seq_id ) )),
266266 ("logits" , POINTER (c_int8 )),
267267 ("all_pos_0" , llama_pos ),
268268 ("all_pos_1" , llama_pos ),
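For orientation, a minimal ctypes sketch (not part of the commit) of what the pointer-to-pointer change means for callers: batch.seq_id[i] is now itself a pointer to token i's array of sequence ids, so a single token can be tagged with several sequences. This assumes a batch allocated by the updated llama_batch_init binding further down, which sizes each per-token array to n_seq_max.

    batch = llama_batch_init(2, 0, 1)  # 2 tokens, token mode (embd == 0), 1 seq id per token
    batch.seq_id[0][0] = 0             # token 0 belongs to sequence 0
    batch.seq_id[1][0] = 0             # token 1 belongs to sequence 0
    llama_batch_free(batch)            # heap batches must be released with llama_batch_free()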
@@ -1069,22 +1069,26 @@ def llama_batch_get_one(
 _lib.llama_batch_get_one.restype = llama_batch
 
 
-# // Allocates a batch of tokens on the heap
+# // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
+# // Each token can be assigned up to n_seq_max sequence ids
 # // The batch has to be freed with llama_batch_free()
 # // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
 # // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
 # // The rest of the llama_batch members are allocated with size n_tokens
 # // All members are left uninitialized
 # LLAMA_API struct llama_batch llama_batch_init(
 #         int32_t n_tokens,
-#         int32_t embd);
+#         int32_t embd,
+#         int32_t n_seq_max);
 def llama_batch_init(
-    n_tokens: Union[c_int, int], embd: Union[c_int, int]
+    n_tokens: Union[c_int32, int],
+    embd: Union[c_int32, int],
+    n_seq_max: Union[c_int32, int],
 ) -> llama_batch:
-    return _lib.llama_batch_init(n_tokens, embd)
+    return _lib.llama_batch_init(n_tokens, embd, n_seq_max)
 
 
-_lib.llama_batch_init.argtypes = [c_int, c_int]
+_lib.llama_batch_init.argtypes = [c_int32, c_int32, c_int32]
 _lib.llama_batch_init.restype = llama_batch
 
 
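To make the allocation contract above concrete, a hypothetical sketch (the sizes are made up) of the two modes the comment describes:

    tok_batch = llama_batch_init(512, 0, 1)     # embd == 0: .token can hold 512 llama_token
    emb_batch = llama_batch_init(512, 4096, 1)  # embd != 0: .embd gets 512 * 4096 floats
    # all members start uninitialized and must be filled before use
    llama_batch_free(tok_batch)
    llama_batch_free(emb_batch)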
@@ -1308,6 +1312,46 @@ def llama_tokenize(
 _lib.llama_tokenize.restype = c_int
 
 
+# /// @details Convert the provided text into tokens.
+# /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
+# /// @return Returns the number of tokens on success, no more than n_max_tokens
+# /// @return Returns a negative number on failure - the number of tokens that would have been returned
+# /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
+# ///                Does not insert a leading space.
+# LLAMA_API int llama_tokenize(
+#         const struct llama_model * model,
+#         const char * text,
+#         int text_len,
+#         llama_token * tokens,
+#         int n_max_tokens,
+#         bool add_bos,
+#         bool special);
+def llama_tokenize(
+    model: llama_model_p,
+    text: bytes,
+    text_len: Union[c_int, int],
+    tokens,  # type: Array[llama_token]
+    n_max_tokens: Union[c_int, int],
+    add_bos: Union[c_bool, bool],
+    special: Union[c_bool, bool],
+) -> int:
+    return _lib.llama_tokenize(
+        model, text, text_len, tokens, n_max_tokens, add_bos, special
+    )
+
+
+_lib.llama_tokenize.argtypes = [
+    llama_model_p,
+    c_char_p,
+    c_int,
+    llama_token_p,
+    c_int,
+    c_bool,
+    c_bool,
+]
+_lib.llama_tokenize.restype = c_int
+
+
 # // Token Id -> Piece.
 # // Uses the vocabulary in the provided context.
 # // Does not write null terminator to the buffer.
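A hedged usage sketch of the new model-level binding (the model handle is assumed to come from llama_load_model_from_file): since a negative return reports how many tokens the text would need, a caller can retry once with an exact-size buffer.

    text = b"Hello, world"
    n_max_tokens = 8
    tokens = (llama_token * n_max_tokens)()
    n = llama_tokenize(model, text, len(text), tokens, n_max_tokens, True, False)
    if n < 0:
        n_max_tokens = -n  # the size that would have been required
        tokens = (llama_token * n_max_tokens)()
        n = llama_tokenize(model, text, len(text), tokens, n_max_tokens, True, False)
    result = tokens[:n]  # list of token ids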