2 files changed: +18 −2 lines changed
lines changed Original file line number Diff line number Diff line change @@ -1041,7 +1041,7 @@ def embed(
10411041 data : Union [List [List [float ]], List [List [List [float ]]]] = []
10421042
10431043 def decode_batch (seq_sizes : List [int ]):
1044- llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
1044+ llama_cpp .llama_kv_self_clear (self ._ctx .ctx )
10451045 self ._ctx .decode (self ._batch )
10461046 self ._batch .reset ()
10471047
@@ -1112,7 +1112,7 @@ def decode_batch(seq_sizes: List[int]):
11121112
11131113 output = data [0 ] if isinstance (input , str ) else data
11141114
1115- llama_cpp .llama_kv_cache_clear (self ._ctx .ctx )
1115+ llama_cpp .llama_kv_self_clear (self ._ctx .ctx )
11161116 self .reset ()
11171117
11181118 if return_count :
Original file line number Diff line number Diff line change @@ -216,3 +216,19 @@ def logit_processor_func(input_ids, logits):
216216
217217 assert number_1 != number_2
218218 assert number_1 == number_3
219+
220+
def test_real_llama_embeddings(llama_cpp_model_path):
    """Smoke-test embedding generation against a real model file.

    Loads the model from the ``llama_cpp_model_path`` fixture in embedding
    mode with a tiny context/batch size and verifies that ``embed`` runs
    without raising. No output values are checked yet.
    """
    # Query the core count once instead of per keyword argument.
    thread_count = multiprocessing.cpu_count()
    model = llama_cpp.Llama(
        llama_cpp_model_path,
        n_ctx=32,
        n_batch=32,
        n_ubatch=32,
        n_threads=thread_count,
        n_threads_batch=thread_count,
        logits_all=False,
        flash_attn=True,
        embedding=True,
    )
    # Smoke test for now
    model.embed("Hello World")
You can’t perform that action at this time.
0 commit comments