Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions sqlite-lembed.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,11 @@ static void lembed_model_options_(sqlite3_context *context, int argc,

typedef struct lembed_context_options lembed_context_options;
struct lembed_context_options {
uint32_t seed;
uint32_t n_ctx;
enum llama_rope_scaling_type rope_scaling_type;
float rope_freq_scale;

int8_t defined[4];
int8_t defined[3];
};
static char *POINTER_NAME_CONTEXT_OPTIONS = "lembed_context_options";

Expand All @@ -205,16 +204,11 @@ static void lembed_context_options_(sqlite3_context *context, int argc,
sqlite3_value *value = argv[i + 1];
assert(sqlite3_value_type(key) == SQLITE_TEXT);
const char *k = (const char *)sqlite3_value_text(key);
if (sqlite3_stricmp("seed", k) == 0) {
sqlite3_int64 v = sqlite3_value_int64(value);
assert(v > 0);
o->seed = v;
o->defined[0] = 1;
} else if (sqlite3_stricmp("n_ctx", k) == 0) {
if (sqlite3_stricmp("n_ctx", k) == 0) {
sqlite3_int64 v = sqlite3_value_int64(value);
assert(v > 0);
o->n_ctx = v;
o->defined[1] = 1;
o->defined[0] = 1;
} else if (sqlite3_stricmp("rope_scaling_type", k) == 0) {
const char *v = (const char *)sqlite3_value_text(value);
if (sqlite3_stricmp(v, "none")) {
Expand All @@ -227,10 +221,10 @@ static void lembed_context_options_(sqlite3_context *context, int argc,
abort();
}

o->defined[2] = 1;
o->defined[1] = 1;
} else if (sqlite3_stricmp(k, "rope_freq_scale") == 0) {
o->rope_freq_scale = sqlite3_value_double(value);
o->defined[3] = 1;
o->defined[2] = 1;
} else {
abort();
}
Expand Down Expand Up @@ -360,7 +354,7 @@ static void lembed_token_to_piece_(sqlite3_context *context, int argc,
int32_t token = sqlite3_value_int(argv[1]);
#define BUFLEN 256
char buf[BUFLEN];
int n = llama_token_to_piece(model, token, buf, BUFLEN, false);
int n = llama_token_to_piece(model, token, buf, BUFLEN, 0, false);
if (n) {
sqlite3_result_text(context, buf, n, SQLITE_TRANSIENT);
} else {
Expand Down Expand Up @@ -480,15 +474,12 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc,
cparams.embeddings = 1;
if (contextOptions) {
if (contextOptions->defined[0]) {
cparams.seed = contextOptions->seed;
}
if (contextOptions->defined[1]) {
cparams.n_ctx = contextOptions->n_ctx;
}
if (contextOptions->defined[2]) {
if (contextOptions->defined[1]) {
cparams.rope_scaling_type = contextOptions->rope_scaling_type;
}
if (contextOptions->defined[3]) {
if (contextOptions->defined[2]) {
cparams.rope_freq_scale = contextOptions->rope_freq_scale;
}
}
Expand Down Expand Up @@ -742,7 +733,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,
for (int j = 0; j < chunk_size; j++) {
int32_t token = tokens[i * chunk_size + j];
int32_t piece_len_neg =
llama_token_to_piece(model, token, NULL, 0, false);
llama_token_to_piece(model, token, NULL, 0, 0, false);
// printf("%d\n", piece_len_neg);
// assert(piece_len_neg < 0);
int32_t piece_len = abs(piece_len_neg);
Expand All @@ -753,7 +744,7 @@ static int lembed_chunksFilter(sqlite3_vtab_cursor *pVtabCursor, int idxNum,

char *piece = sqlite3_malloc(piece_len);
assert(piece);
llama_token_to_piece(model, token, piece, piece_len, false);
llama_token_to_piece(model, token, piece, piece_len, 0, false);
// printf("'%.*s' %d ", piece_len, piece, tokens[i*chunk_size + j]);

char *begin = ptr;
Expand Down