Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions sqlite-lembed.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ int embed_single(struct llama_model *model, struct llama_context *context,
float **out_embedding,
/** Output embedding length (n dimensions) */
int *out_dimensions) {
int n_batch = 512;
int n_ctx_train = llama_n_ctx_train(model);
int n_ctx = llama_n_ctx(context);

llama_token *tokens;
Expand All @@ -70,6 +68,12 @@ int embed_single(struct llama_model *model, struct llama_context *context,
return rc;
}

int n_batch = llama_n_batch(context);
if (token_count > n_batch) {
// Silently truncate the input to the context's batch size.
// TODO: return an error instead, unless the user passes allow_truncate=true?
token_count = n_batch;
}
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

int seq_id = 0;
Expand Down Expand Up @@ -478,6 +482,7 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc,
struct llama_context *ctx;
struct llama_context_params cparams = llama_context_default_params();
cparams.embeddings = 1;
cparams.n_ctx = llama_n_ctx_train(model);
if (contextOptions) {
if (contextOptions->defined[0]) {
cparams.seed = contextOptions->seed;
Expand All @@ -492,6 +497,8 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc,
cparams.rope_freq_scale = contextOptions->rope_freq_scale;
}
}
cparams.n_ubatch = cparams.n_ctx;
cparams.n_batch = cparams.n_ctx;

ctx = llama_new_context_with_model(model, cparams);
if (!ctx) {
Expand Down
3 changes: 3 additions & 0 deletions tests/test-loadable.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ def test_lembed():
assert struct.unpack("1f", a[0:4])[0] == pytest.approx(
-0.09205757826566696, rel=1e-2
)

# test input larger than default 512 tokens
lembed('ab' * 1000)


@pytest.mark.skip(reason="TODO")
Expand Down