@@ -60,6 +60,7 @@ extern "C" {
6060 struct llama_model ;
6161 struct llama_context ;
6262 struct llama_sampler ;
63+ struct llama_kv_cache ;
6364
6465 typedef int32_t llama_pos;
6566 typedef int32_t llama_token;
@@ -467,8 +468,9 @@ extern "C" {
467468
468469 DEPRECATED (LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
469470
470- LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
471- LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
471+ LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
472+ LLAMA_API struct llama_kv_cache * llama_get_kv_cache ( struct llama_context * ctx);
473+ LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
472474
473475 LLAMA_API const struct llama_vocab * llama_model_get_vocab (const struct llama_model * model);
474476 LLAMA_API enum llama_rope_type llama_model_rope_type (const struct llama_model * model);
@@ -583,7 +585,7 @@ extern "C" {
583585 // KV cache
584586 //
585587
586- // TODO: remove llama_kv_cache_view_* API
588+ // TODO: start using struct llama_kv_cache
587589
588590 // Information associated with an individual cell in the KV cache view.
589591 struct llama_kv_cache_view_cell {
@@ -638,41 +640,47 @@ extern "C" {
638640
639641 // Returns the number of tokens in the KV cache (slow, use only for debug)
640642 // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
641- LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx);
643+ LLAMA_API int32_t llama_kv_cache_n_tokens (const struct llama_kv_cache * kv);
644+
645+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx),
646+ "use llama_kv_cache_n_tokens instead");
642647
643648 // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
644- LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx);
649+ LLAMA_API int32_t llama_kv_cache_used_cells (const struct llama_kv_cache * kv);
650+
651+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx),
652+ "use llama_kv_cache_used_cells instead");
645653
646654 // Clear the KV cache - both cell info is erased and KV data is zeroed
647655 LLAMA_API void llama_kv_cache_clear (
648- struct llama_context * ctx );
656+ struct llama_kv_cache * kv );
649657
650658 // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
651659 // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
652660 // seq_id < 0 : match any sequence
653661 // p0 < 0 : [0, p1]
654662 // p1 < 0 : [p0, inf)
655663 LLAMA_API bool llama_kv_cache_seq_rm (
656- struct llama_context * ctx ,
657- llama_seq_id seq_id,
658- llama_pos p0,
659- llama_pos p1);
664+ struct llama_kv_cache * kv ,
665+ llama_seq_id seq_id,
666+ llama_pos p0,
667+ llama_pos p1);
660668
661669 // Copy all tokens that belong to the specified sequence to another sequence
662670 // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
663671 // p0 < 0 : [0, p1]
664672 // p1 < 0 : [p0, inf)
665673 LLAMA_API void llama_kv_cache_seq_cp (
666- struct llama_context * ctx ,
667- llama_seq_id seq_id_src,
668- llama_seq_id seq_id_dst,
669- llama_pos p0,
670- llama_pos p1);
674+ struct llama_kv_cache * kv ,
675+ llama_seq_id seq_id_src,
676+ llama_seq_id seq_id_dst,
677+ llama_pos p0,
678+ llama_pos p1);
671679
672680 // Removes all tokens that do not belong to the specified sequence
673681 LLAMA_API void llama_kv_cache_seq_keep (
674- struct llama_context * ctx ,
675- llama_seq_id seq_id);
682+ struct llama_kv_cache * kv ,
683+ llama_seq_id seq_id);
676684
677685 // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
678686 // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -681,11 +689,11 @@ extern "C" {
681689 // p0 < 0 : [0, p1]
682690 // p1 < 0 : [p0, inf)
683691 LLAMA_API void llama_kv_cache_seq_add (
684- struct llama_context * ctx ,
685- llama_seq_id seq_id,
686- llama_pos p0,
687- llama_pos p1,
688- llama_pos delta);
692+ struct llama_kv_cache * kv ,
693+ llama_seq_id seq_id,
694+ llama_pos p0,
695+ llama_pos p1,
696+ llama_pos delta);
689697
690698 // Integer division of the positions by factor of `d > 1`
691699 // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -694,31 +702,28 @@ extern "C" {
694702 // p0 < 0 : [0, p1]
695703 // p1 < 0 : [p0, inf)
696704 LLAMA_API void llama_kv_cache_seq_div (
697- struct llama_context * ctx ,
698- llama_seq_id seq_id,
699- llama_pos p0,
700- llama_pos p1,
701- int d);
705+ struct llama_kv_cache * kv ,
706+ llama_seq_id seq_id,
707+ llama_pos p0,
708+ llama_pos p1,
709+ int d);
702710
703711 // Returns the largest position present in the KV cache for the specified sequence
704712 LLAMA_API llama_pos llama_kv_cache_seq_pos_max (
705- struct llama_context * ctx,
706- llama_seq_id seq_id);
707-
708- // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
709- // how to avoid this?
713+ struct llama_kv_cache * kv,
714+ llama_seq_id seq_id);
710715
711716 // Defragment the KV cache
712717 // This will be applied:
713718 // - lazily on next llama_decode()
714719 // - explicitly with llama_kv_cache_update()
715- LLAMA_API void llama_kv_cache_defrag (struct llama_context * ctx);
716-
717- // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
718- LLAMA_API void llama_kv_cache_update (struct llama_context * ctx);
720+ LLAMA_API void llama_kv_cache_defrag (struct llama_kv_cache * kv);
719721
720722 // Check if the context supports KV cache shifting
721- LLAMA_API bool llama_kv_cache_can_shift (struct llama_context * ctx);
723+ LLAMA_API bool llama_kv_cache_can_shift (const struct llama_kv_cache * kv);
724+
725+ // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
726+ LLAMA_API void llama_update_kv_cache (struct llama_context * ctx, struct llama_kv_cache * kv);
722727
723728 //
724729 // State / sessions
0 commit comments