diff --git a/sqlite-lembed.c b/sqlite-lembed.c index 479a554..d57719f 100644 --- a/sqlite-lembed.c +++ b/sqlite-lembed.c @@ -183,12 +183,11 @@ static void lembed_model_options_(sqlite3_context *context, int argc, typedef struct lembed_context_options lembed_context_options; struct lembed_context_options { - uint32_t seed; uint32_t n_ctx; enum llama_rope_scaling_type rope_scaling_type; float rope_freq_scale; - int8_t defined[4]; + int8_t defined[3]; }; static char *POINTER_NAME_CONTEXT_OPTIONS = "lembed_context_options"; @@ -205,16 +204,11 @@ static void lembed_context_options_(sqlite3_context *context, int argc, sqlite3_value *value = argv[i + 1]; assert(sqlite3_value_type(key) == SQLITE_TEXT); const char *k = (const char *)sqlite3_value_text(key); - if (sqlite3_stricmp("seed", k) == 0) { - sqlite3_int64 v = sqlite3_value_int64(value); - assert(v > 0); - o->seed = v; - o->defined[0] = 1; - } else if (sqlite3_stricmp("n_ctx", k) == 0) { + if (sqlite3_stricmp("n_ctx", k) == 0) { sqlite3_int64 v = sqlite3_value_int64(value); assert(v > 0); o->n_ctx = v; - o->defined[1] = 1; + o->defined[0] = 1; } else if (sqlite3_stricmp("rope_scaling_type", k) == 0) { const char *v = (const char *)sqlite3_value_text(value); if (sqlite3_stricmp(v, "none")) { @@ -227,10 +221,10 @@ static void lembed_context_options_(sqlite3_context *context, int argc, abort(); } - o->defined[2] = 1; + o->defined[1] = 1; } else if (sqlite3_stricmp(k, "rope_freq_scale") == 0) { o->rope_freq_scale = sqlite3_value_double(value); - o->defined[3] = 1; + o->defined[2] = 1; } else { abort(); } @@ -360,7 +354,7 @@ static void lembed_token_to_piece_(sqlite3_context *context, int argc, int32_t token = sqlite3_value_int(argv[1]); #define BUFLEN 256 char buf[BUFLEN]; - int n = llama_token_to_piece(model, token, buf, BUFLEN, false); + int n = llama_token_to_piece(model, token, buf, BUFLEN, 0, false); if (n) { sqlite3_result_text(context, buf, n, SQLITE_TRANSIENT); } else { @@ -480,15 +474,12 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc, cparams.embeddings = 1; if (contextOptions) { if (contextOptions->defined[0]) { - cparams.seed = contextOptions->seed; - } - if (contextOptions->defined[1]) { cparams.n_ctx = contextOptions->n_ctx; } - if (contextOptions->defined[2]) { + if (contextOptions->defined[1]) { cparams.rope_scaling_type = contextOptions->rope_scaling_type; } - if (contextOptions->defined[3]) { + if (contextOptions->defined[2]) { cparams.rope_freq_scale = contextOptions->rope_freq_scale; } } @@ -742,7 +733,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, for (int j = 0; j < chunk_size; j++) { int32_t token = tokens[i * chunk_size + j]; int32_t piece_len_neg = - llama_token_to_piece(model, token, NULL, 0, false); + llama_token_to_piece(model, token, NULL, 0, 0, false); // printf("%d\n", piece_len_neg); // assert(piece_len_neg < 0); int32_t piece_len = abs(piece_len_neg); @@ -753,7 +744,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum, char *piece = sqlite3_malloc(piece_len); assert(piece); - llama_token_to_piece(model, token, piece, piece_len, false); + llama_token_to_piece(model, token, piece, piece_len, 0, false); // printf("'%.*s' %d ", piece_len, piece, tokens[i*chunk_size + j]); char *begin = ptr;