diff --git a/sqlite-lembed.c b/sqlite-lembed.c index 479a554..7c3caf9 100644 --- a/sqlite-lembed.c +++ b/sqlite-lembed.c @@ -58,8 +58,6 @@ int embed_single(struct llama_model *model, struct llama_context *context, float **out_embedding, /** Output embedding length (n dimensions) */ int *out_dimensions) { - int n_batch = 512; - int n_ctx_train = llama_n_ctx_train(model); int n_ctx = llama_n_ctx(context); llama_token *tokens; @@ -70,6 +68,12 @@ int embed_single(struct llama_model *model, struct llama_context *context, return rc; } + int n_batch = llama_n_batch(context); + if (token_count > n_batch) { + // Truncate silently. + // TODO: explode w/ an error unless user passes allow_truncate=true? + token_count = n_batch; + } struct llama_batch batch = llama_batch_init(n_batch, 0, 1); int seq_id = 0; @@ -478,6 +482,7 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc, struct llama_context *ctx; struct llama_context_params cparams = llama_context_default_params(); cparams.embeddings = 1; + cparams.n_ctx = llama_n_ctx_train(model); if (contextOptions) { if (contextOptions->defined[0]) { cparams.seed = contextOptions->seed; @@ -492,6 +497,8 @@ static int lembed_modelsUpdate(sqlite3_vtab *pVTab, int argc, cparams.rope_freq_scale = contextOptions->rope_freq_scale; } } + cparams.n_ubatch = cparams.n_ctx; + cparams.n_batch = cparams.n_ctx; ctx = llama_new_context_with_model(model, cparams); if (!ctx) { diff --git a/tests/test-loadable.py b/tests/test-loadable.py index bf1166c..2b1800d 100644 --- a/tests/test-loadable.py +++ b/tests/test-loadable.py @@ -140,6 +140,9 @@ def test_lembed(): assert struct.unpack("1f", a[0:4])[0] == pytest.approx( -0.09205757826566696, rel=1e-2 ) + + # test input larger than default 512 tokens + lembed('ab' * 1000) @pytest.mark.skip(reason="TODO")