@@ -60,6 +60,7 @@ extern "C" {
6060 struct llama_model ;
6161 struct llama_context ;
6262 struct llama_sampler ;
63+ struct llama_kv_cache ;
6364
6465 typedef int32_t llama_pos;
6566 typedef int32_t llama_token;
@@ -470,8 +471,9 @@ extern "C" {
470471
471472 DEPRECATED (LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
472473
473- LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
474- LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
474+ LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
475+ LLAMA_API struct llama_kv_cache * llama_get_kv_cache ( struct llama_context * ctx);
476+ LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
475477
476478 LLAMA_API const struct llama_vocab * llama_model_get_vocab (const struct llama_model * model);
477479 LLAMA_API enum llama_rope_type llama_model_rope_type (const struct llama_model * model);
@@ -586,7 +588,7 @@ extern "C" {
586588 // KV cache
587589 //
588590
589- // TODO: remove llama_kv_cache_view_* API
591+ // TODO: start using struct llama_kv_cache
590592
591593 // Information associated with an individual cell in the KV cache view.
592594 struct llama_kv_cache_view_cell {
@@ -641,41 +643,47 @@ extern "C" {
641643
642644 // Returns the number of tokens in the KV cache (slow, use only for debug)
643645 // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
644- LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx);
646+ LLAMA_API int32_t llama_kv_cache_n_tokens (const struct llama_kv_cache * kv);
647+
648+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx),
649+ "use llama_kv_cache_n_tokens instead");
645650
646651 // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
647- LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx);
652+ LLAMA_API int32_t llama_kv_cache_used_cells (const struct llama_kv_cache * kv);
653+
654+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx),
655+ "use llama_kv_cache_used_cells instead");
648656
649657 // Clear the KV cache - both cell info is erased and KV data is zeroed
650658 LLAMA_API void llama_kv_cache_clear (
651- struct llama_context * ctx );
659+ struct llama_kv_cache * kv );
652660
653661 // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
654662 // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
655663 // seq_id < 0 : match any sequence
656664 // p0 < 0 : [0, p1]
657665 // p1 < 0 : [p0, inf)
658666 LLAMA_API bool llama_kv_cache_seq_rm (
659- struct llama_context * ctx ,
660- llama_seq_id seq_id,
661- llama_pos p0,
662- llama_pos p1);
667+ struct llama_kv_cache * kv ,
668+ llama_seq_id seq_id,
669+ llama_pos p0,
670+ llama_pos p1);
663671
664672 // Copy all tokens that belong to the specified sequence to another sequence
665673 // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
666674 // p0 < 0 : [0, p1]
667675 // p1 < 0 : [p0, inf)
668676 LLAMA_API void llama_kv_cache_seq_cp (
669- struct llama_context * ctx ,
670- llama_seq_id seq_id_src,
671- llama_seq_id seq_id_dst,
672- llama_pos p0,
673- llama_pos p1);
677+ struct llama_kv_cache * kv ,
678+ llama_seq_id seq_id_src,
679+ llama_seq_id seq_id_dst,
680+ llama_pos p0,
681+ llama_pos p1);
674682
675683 // Removes all tokens that do not belong to the specified sequence
676684 LLAMA_API void llama_kv_cache_seq_keep (
677- struct llama_context * ctx ,
678- llama_seq_id seq_id);
685+ struct llama_kv_cache * kv ,
686+ llama_seq_id seq_id);
679687
680688 // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
681689 // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -684,11 +692,11 @@ extern "C" {
684692 // p0 < 0 : [0, p1]
685693 // p1 < 0 : [p0, inf)
686694 LLAMA_API void llama_kv_cache_seq_add (
687- struct llama_context * ctx ,
688- llama_seq_id seq_id,
689- llama_pos p0,
690- llama_pos p1,
691- llama_pos delta);
695+ struct llama_kv_cache * kv ,
696+ llama_seq_id seq_id,
697+ llama_pos p0,
698+ llama_pos p1,
699+ llama_pos delta);
692700
693701 // Integer division of the positions by factor of `d > 1`
694702 // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -697,31 +705,28 @@ extern "C" {
697705 // p0 < 0 : [0, p1]
698706 // p1 < 0 : [p0, inf)
699707 LLAMA_API void llama_kv_cache_seq_div (
700- struct llama_context * ctx ,
701- llama_seq_id seq_id,
702- llama_pos p0,
703- llama_pos p1,
704- int d);
708+ struct llama_kv_cache * kv ,
709+ llama_seq_id seq_id,
710+ llama_pos p0,
711+ llama_pos p1,
712+ int d);
705713
706714 // Returns the largest position present in the KV cache for the specified sequence
707715 LLAMA_API llama_pos llama_kv_cache_seq_pos_max (
708- struct llama_context * ctx,
709- llama_seq_id seq_id);
710-
711- // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
712- // how to avoid this?
716+ struct llama_kv_cache * kv,
717+ llama_seq_id seq_id);
713718
714719 // Defragment the KV cache
715720 // This will be applied:
716721 // - lazily on next llama_decode()
717722 // - explicitly with llama_kv_cache_update()
718- LLAMA_API void llama_kv_cache_defrag (struct llama_context * ctx);
719-
720- // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
721- LLAMA_API void llama_kv_cache_update (struct llama_context * ctx);
723+ LLAMA_API void llama_kv_cache_defrag (struct llama_kv_cache * kv);
722724
723725 // Check if the context supports KV cache shifting
724- LLAMA_API bool llama_kv_cache_can_shift (struct llama_context * ctx);
726+ LLAMA_API bool llama_kv_cache_can_shift (const struct llama_kv_cache * kv);
727+
728+ // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
729+ LLAMA_API void llama_update_kv_cache (struct llama_context * ctx, struct llama_kv_cache * kv);
725730
726731 //
727732 // State / sessions
0 commit comments