Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 79 additions & 45 deletions sqlite-lembed.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include "sqlite-lembed.h"
#include "llama.h"
#include <llama.h>
#include <assert.h>
#include <math.h>
#include <stdlib.h>
Expand Down Expand Up @@ -28,28 +28,69 @@ static void normalize(float *vec, float *out, int n) {

#define LEMBED_TOKEN_SUBTYPE 116 // ascii 't'

int tokenize(struct llama_model *model, const char *input, size_t input_length,
int *token_count, llama_token **tokens) {
int input_token_count_estimate =
llama_tokenize(model, input, input_length, NULL, 0, true, true);
if (input_token_count_estimate >= 0) {
return SQLITE_ERROR;
}
*tokens =
sqlite3_malloc(sizeof(llama_token) * abs(input_token_count_estimate));
if (!(*tokens)) {
return SQLITE_NOMEM;
}
int input_token_count =
llama_tokenize(model, input, input_length, *tokens,
abs(input_token_count_estimate), true, true);
if (input_token_count != abs(input_token_count_estimate)) {
sqlite3_free(*tokens);
return SQLITE_ERROR;
}
/*
 * Tokenize `input` (length `input_length` bytes) with `model`'s vocabulary.
 *
 * On SQLITE_OK, *tokens is a sqlite3_malloc'd array owned by the caller
 * (free with sqlite3_free) and *token_count is its length. Empty input
 * yields SQLITE_OK with *tokens == NULL and *token_count == 0.
 * Returns SQLITE_ERROR on bad arguments or tokenizer failure,
 * SQLITE_NOMEM on allocation failure; *tokens / *token_count are reset
 * to NULL / 0 on every failure path.
 */
int tokenize(struct llama_model* model, const char* input, size_t input_length,
             int* token_count, llama_token** tokens) {

  if (!model || !input || !token_count || !tokens) {
    return SQLITE_ERROR;
  }

  *tokens = NULL;
  *token_count = 0;

  /* Sizing pass: with a NULL buffer, llama_tokenize reports the required
     token count as a NEGATIVE value. A zero result means empty input;
     a positive result is unexpected here and treated as an error. */
  int required_buffer_size_or_error = llama_tokenize(
      llama_model_get_vocab(model),
      input,
      (int)input_length,
      NULL,
      0,
      true,
      true);

  int n_tokens;
  if (required_buffer_size_or_error >= 0) {
    if (required_buffer_size_or_error != 0) {
      return SQLITE_ERROR;
    }
    n_tokens = 0;
  } else {
    n_tokens = -required_buffer_size_or_error;
  }

  if (n_tokens == 0) {
    /* Nothing to tokenize; outputs already set to NULL / 0. */
    return SQLITE_OK;
  }

  *tokens = (llama_token*)sqlite3_malloc(sizeof(llama_token) * n_tokens);
  if (!(*tokens)) {
    return SQLITE_NOMEM;
  }

  /* Real pass: must fill exactly the number of tokens the sizing pass
     promised; any other result (including a negative error) is a failure. */
  int actual_tokens_written = llama_tokenize(
      llama_model_get_vocab(model),
      input,
      (int)input_length,
      *tokens,
      n_tokens,
      true,
      true);

  if (actual_tokens_written != n_tokens) {
    sqlite3_free(*tokens);
    *tokens = NULL;
    return SQLITE_ERROR;
  }

  *token_count = actual_tokens_written;
  return SQLITE_OK;
}

int embed_single(struct llama_model *model, struct llama_context *context,
Expand Down Expand Up @@ -92,7 +133,7 @@ int embed_single(struct llama_model *model, struct llama_context *context,
return SQLITE_NOMEM;
}

llama_kv_cache_clear(context); // KV not needed for embeddings?
llama_kv_self_clear(context); // KV not needed for embeddings?
rc = llama_decode(context, batch);
if(rc != 0) {
sqlite3_free(output_embedding);
Expand Down Expand Up @@ -183,12 +224,11 @@ static void lembed_model_options_(sqlite3_context *context, int argc,

typedef struct lembed_context_options lembed_context_options;

/*
 * User-supplied overrides for llama_context_params, built by the
 * lembed_context_options() SQL function and applied when a model's
 * context is created. Only fields whose defined[] flag is set are
 * copied onto the default llama_context_params.
 *
 * defined[i] == 1 when the corresponding option was explicitly given:
 *   [0] = n_ctx, [1] = rope_scaling_type, [2] = rope_freq_scale.
 */
struct lembed_context_options {
  uint32_t n_ctx;
  enum llama_rope_scaling_type rope_scaling_type;
  float rope_freq_scale;

  int8_t defined[3];
};
static char *POINTER_NAME_CONTEXT_OPTIONS = "lembed_context_options";

Expand All @@ -205,16 +245,11 @@ static void lembed_context_options_(sqlite3_context *context, int argc,
sqlite3_value *value = argv[i + 1];
assert(sqlite3_value_type(key) == SQLITE_TEXT);
const char *k = (const char *)sqlite3_value_text(key);
if (sqlite3_stricmp("seed", k) == 0) {
sqlite3_int64 v = sqlite3_value_int64(value);
assert(v > 0);
o->seed = v;
o->defined[0] = 1;
} else if (sqlite3_stricmp("n_ctx", k) == 0) {
if (sqlite3_stricmp("n_ctx", k) == 0) {
sqlite3_int64 v = sqlite3_value_int64(value);
assert(v > 0);
o->n_ctx = v;
o->defined[1] = 1;
o->defined[0] = 1;
} else if (sqlite3_stricmp("rope_scaling_type", k) == 0) {
const char *v = (const char *)sqlite3_value_text(value);
if (sqlite3_stricmp(v, "none")) {
Expand All @@ -227,10 +262,10 @@ static void lembed_context_options_(sqlite3_context *context, int argc,
abort();
}

o->defined[2] = 1;
o->defined[1] = 1;
} else if (sqlite3_stricmp(k, "rope_freq_scale") == 0) {
o->rope_freq_scale = sqlite3_value_double(value);
o->defined[3] = 1;
o->defined[2] = 1;
} else {
abort();
}
Expand Down Expand Up @@ -360,7 +395,7 @@ static void lembed_token_to_piece_(sqlite3_context *context, int argc,
int32_t token = sqlite3_value_int(argv[1]);
#define BUFLEN 256
char buf[BUFLEN];
int n = llama_token_to_piece(model, token, buf, BUFLEN, false);
int n = llama_token_to_piece(model, token, buf, BUFLEN, 0, false);
if (n) {
sqlite3_result_text(context, buf, n, SQLITE_TRANSIENT);
} else {
Expand Down Expand Up @@ -466,11 +501,13 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc,

struct llama_model *model;
struct llama_model_params mparams = llama_model_default_params();


if (modelOptions && modelOptions->defined[0]) {
mparams.n_gpu_layers = modelOptions->n_gpu_layers;
}

model = llama_load_model_from_file(modelPath, mparams);
model = llama_model_load_from_file(modelPath, mparams);
if (!model) {
return SQLITE_ERROR;
}
Expand All @@ -480,22 +517,19 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc,
cparams.embeddings = 1;
if (contextOptions) {
if (contextOptions->defined[0]) {
cparams.seed = contextOptions->seed;
}
if (contextOptions->defined[1]) {
cparams.n_ctx = contextOptions->n_ctx;
}
if (contextOptions->defined[2]) {
if (contextOptions->defined[1]) {
cparams.rope_scaling_type = contextOptions->rope_scaling_type;
}
if (contextOptions->defined[3]) {
if (contextOptions->defined[2]) {
cparams.rope_freq_scale = contextOptions->rope_freq_scale;
}
}

ctx = llama_new_context_with_model(model, cparams);
ctx = llama_init_from_model(model, cparams);
if (!ctx) {
llama_free_model(model);
llama_model_free(model);
return SQLITE_ERROR;
}
p->api->models[idx].model = model;
Expand Down Expand Up @@ -742,7 +776,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
for (int j = 0; j < chunk_size; j++) {
int32_t token = tokens[i * chunk_size + j];
int32_t piece_len_neg =
llama_token_to_piece(model, token, NULL, 0, false);
llama_token_to_piece(model, token, NULL, 0, 0, false);
// printf("%d\n", piece_len_neg);
// assert(piece_len_neg < 0);
int32_t piece_len = abs(piece_len_neg);
Expand All @@ -753,7 +787,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,

char *piece = sqlite3_malloc(piece_len);
assert(piece);
llama_token_to_piece(model, token, piece, piece_len, false);
llama_token_to_piece(model, token, piece, piece_len, 0, false);
// printf("'%.*s' %d ", piece_len, piece, tokens[i*chunk_size + j]);

char *begin = ptr;
Expand Down