diff --git a/sqlite-lembed.c b/sqlite-lembed.c index 479a554..87ee143 100644 --- a/sqlite-lembed.c +++ b/sqlite-lembed.c @@ -1,5 +1,5 @@ #include "sqlite-lembed.h" -#include "llama.h" +#include #include #include #include @@ -28,28 +28,69 @@ static void normalize(float *vec, float *out, int n) { #define LEMBED_TOKEN_SUBTYPE 116 // ascii 't' -int tokenize(struct llama_model *model, const char *input, size_t input_length, - int *token_count, llama_token **tokens) { - int input_token_count_estimate = - llama_tokenize(model, input, input_length, NULL, 0, true, true); - if (input_token_count_estimate >= 0) { - return SQLITE_ERROR; - } - *tokens = - sqlite3_malloc(sizeof(llama_token) * abs(input_token_count_estimate)); - if (!(*tokens)) { - return SQLITE_NOMEM; - } - int input_token_count = - llama_tokenize(model, input, input_length, *tokens, - abs(input_token_count_estimate), true, true); - if (input_token_count != abs(input_token_count_estimate)) { - sqlite3_free(*tokens); - return SQLITE_ERROR; - } +int tokenize(struct llama_model* model, const char* input, size_t input_length, + int* token_count, llama_token** tokens) { - *token_count = input_token_count; - return SQLITE_OK; + if (!model || !input || !token_count || !tokens) { + return SQLITE_ERROR; + } + + *tokens = NULL; + *token_count = 0; + + int required_buffer_size_or_error = llama_tokenize( + llama_model_get_vocab(model), + input, + (int)input_length, + NULL, + 0, + true, + true + ); + + int n_tokens; + if (required_buffer_size_or_error >= 0) { + if (required_buffer_size_or_error == 0) { + n_tokens = 0; + } + else { + return SQLITE_ERROR; + } + } + else { + n_tokens = -required_buffer_size_or_error; + } + + if (n_tokens == 0) { + *tokens = NULL; + *token_count = 0; + return SQLITE_OK; + } + + *tokens = (llama_token*)sqlite3_malloc(sizeof(llama_token) * n_tokens); + if (!(*tokens)) { + return SQLITE_NOMEM; + } + + int actual_tokens_written = llama_tokenize( + llama_model_get_vocab(model), + input, + (int)input_length, + *tokens, + n_tokens, + true, + true + ); + + if (actual_tokens_written < 0 || actual_tokens_written != n_tokens) { + sqlite3_free(*tokens); + *tokens = NULL; + *token_count = 0; + return SQLITE_ERROR; + } + + *token_count = actual_tokens_written; + return SQLITE_OK; } int embed_single(struct llama_model *model, struct llama_context *context, @@ -92,7 +133,7 @@ int embed_single(struct llama_model *model, struct llama_context *context, return SQLITE_NOMEM; } - llama_kv_cache_clear(context); // KV not needed for embeddings? + llama_kv_self_clear(context); // KV not needed for embeddings? rc = llama_decode(context, batch); if(rc != 0) { sqlite3_free(output_embedding); @@ -183,12 +224,11 @@ static void lembed_model_options_(sqlite3_context *context, int argc, typedef struct lembed_context_options lembed_context_options; struct lembed_context_options { - uint32_t seed; uint32_t n_ctx; enum llama_rope_scaling_type rope_scaling_type; float rope_freq_scale; - int8_t defined[4]; + int8_t defined[3]; }; static char *POINTER_NAME_CONTEXT_OPTIONS = "lembed_context_options"; @@ -205,16 +245,11 @@ static void lembed_context_options_(sqlite3_context *context, int argc, sqlite3_value *value = argv[i + 1]; assert(sqlite3_value_type(key) == SQLITE_TEXT); const char *k = (const char *)sqlite3_value_text(key); - if (sqlite3_stricmp("seed", k) == 0) { - sqlite3_int64 v = sqlite3_value_int64(value); - assert(v > 0); - o->seed = v; - o->defined[0] = 1; - } else if (sqlite3_stricmp("n_ctx", k) == 0) { + if (sqlite3_stricmp("n_ctx", k) == 0) { sqlite3_int64 v = sqlite3_value_int64(value); assert(v > 0); o->n_ctx = v; - o->defined[1] = 1; + o->defined[0] = 1; } else if (sqlite3_stricmp("rope_scaling_type", k) == 0) { const char *v = (const char *)sqlite3_value_text(value); if (sqlite3_stricmp(v, "none")) { @@ -227,10 +262,10 @@ static void lembed_context_options_(sqlite3_context *context, int argc, abort(); } - o->defined[2] = 1; + o->defined[1] = 1; } else if (sqlite3_stricmp(k, "rope_freq_scale") == 0) { o->rope_freq_scale = sqlite3_value_double(value); - o->defined[3] = 1; + o->defined[2] = 1; } else { abort(); } @@ -360,7 +395,7 @@ static void lembed_token_to_piece_(sqlite3_context *context, int argc, int32_t token = sqlite3_value_int(argv[1]); #define BUFLEN 256 char buf[BUFLEN]; - int n = llama_token_to_piece(model, token, buf, BUFLEN, false); + int n = llama_token_to_piece(model, token, buf, BUFLEN, 0, false); if (n) { sqlite3_result_text(context, buf, n, SQLITE_TRANSIENT); } else { @@ -466,11 +501,13 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc, struct llama_model *model; struct llama_model_params mparams = llama_model_default_params(); + + if (modelOptions && modelOptions->defined[0]) { mparams.n_gpu_layers = modelOptions->n_gpu_layers; } - model = llama_load_model_from_file(modelPath, mparams); + model = llama_model_load_from_file(modelPath, mparams); if (!model) { return SQLITE_ERROR; } @@ -480,22 +517,19 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc, cparams.embeddings = 1; if (contextOptions) { if (contextOptions->defined[0]) { - cparams.seed = contextOptions->seed; - } - if (contextOptions->defined[1]) { cparams.n_ctx = contextOptions->n_ctx; } - if (contextOptions->defined[2]) { + if (contextOptions->defined[1]) { cparams.rope_scaling_type = contextOptions->rope_scaling_type; } - if (contextOptions->defined[3]) { + if (contextOptions->defined[2]) { cparams.rope_freq_scale = contextOptions->rope_freq_scale; } } - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (!ctx) { - llama_free_model(model); + llama_model_free(model); return SQLITE_ERROR; } p->api->models[idx].model = model; @@ -742,7 +776,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, for (int j = 0; j < chunk_size; j++) { int32_t token = tokens[i * chunk_size + j]; int32_t piece_len_neg = - llama_token_to_piece(model, token, NULL, 0, false); + llama_token_to_piece(model, token, NULL, 0, 0, false); // printf("%d\n", piece_len_neg); // assert(piece_len_neg < 0); int32_t piece_len = abs(piece_len_neg); @@ -753,7 +787,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, char *piece = sqlite3_malloc(piece_len); assert(piece); - llama_token_to_piece(model, token, piece, piece_len, false); + llama_token_to_piece(model, token, piece, piece_len, 0, false); // printf("'%.*s' %d ", piece_len, piece, tokens[i*chunk_size + j]); char *begin = ptr;