@@ -60,6 +60,7 @@ extern "C" {
6060 struct llama_model ;
6161 struct llama_context ;
6262 struct llama_sampler ;
63+ struct llama_kv_cache ;
6364
6465 typedef int32_t llama_pos;
6566 typedef int32_t llama_token;
@@ -460,8 +461,9 @@ extern "C" {
460461
461462 DEPRECATED (LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
462463
463- LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
464- LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
464+ LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
465+ LLAMA_API struct llama_kv_cache * llama_get_kv_cache ( struct llama_context * ctx);
466+ LLAMA_API enum llama_pooling_type llama_pooling_type (const struct llama_context * ctx);
465467
466468 LLAMA_API const struct llama_vocab * llama_model_get_vocab (const struct llama_model * model);
467469 LLAMA_API enum llama_rope_type llama_model_rope_type (const struct llama_model * model);
@@ -576,7 +578,7 @@ extern "C" {
576578 // KV cache
577579 //
578580
579- // TODO: remove llama_kv_cache_view_* API
581+ // TODO: start using struct llama_kv_cache
580582
581583 // Information associated with an individual cell in the KV cache view.
582584 struct llama_kv_cache_view_cell {
@@ -631,41 +633,47 @@ extern "C" {
631633
632634 // Returns the number of tokens in the KV cache (slow, use only for debug)
633635 // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
634- LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx);
636+ LLAMA_API int32_t llama_kv_cache_n_tokens (const struct llama_kv_cache * kv);
637+
638+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_token_count (const struct llama_context * ctx),
639+ "use llama_kv_cache_n_tokens instead");
635640
636641 // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
637- LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx);
642+ LLAMA_API int32_t llama_kv_cache_used_cells (const struct llama_kv_cache * kv);
643+
644+ DEPRECATED (LLAMA_API int32_t llama_get_kv_cache_used_cells (const struct llama_context * ctx),
645+ "use llama_kv_cache_used_cells instead");
638646
639647 // Clear the KV cache - both cell info is erased and KV data is zeroed
640648 LLAMA_API void llama_kv_cache_clear (
641- struct llama_context * ctx );
649+ struct llama_kv_cache * kv );
642650
643651 // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
644652 // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
645653 // seq_id < 0 : match any sequence
646654 // p0 < 0 : [0, p1]
647655 // p1 < 0 : [p0, inf)
648656 LLAMA_API bool llama_kv_cache_seq_rm (
649- struct llama_context * ctx ,
650- llama_seq_id seq_id,
651- llama_pos p0,
652- llama_pos p1);
657+ struct llama_kv_cache * kv ,
658+ llama_seq_id seq_id,
659+ llama_pos p0,
660+ llama_pos p1);
653661
654662 // Copy all tokens that belong to the specified sequence to another sequence
655663 // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
656664 // p0 < 0 : [0, p1]
657665 // p1 < 0 : [p0, inf)
658666 LLAMA_API void llama_kv_cache_seq_cp (
659- struct llama_context * ctx ,
660- llama_seq_id seq_id_src,
661- llama_seq_id seq_id_dst,
662- llama_pos p0,
663- llama_pos p1);
667+ struct llama_kv_cache * kv ,
668+ llama_seq_id seq_id_src,
669+ llama_seq_id seq_id_dst,
670+ llama_pos p0,
671+ llama_pos p1);
664672
665673 // Removes all tokens that do not belong to the specified sequence
666674 LLAMA_API void llama_kv_cache_seq_keep (
667- struct llama_context * ctx ,
668- llama_seq_id seq_id);
675+ struct llama_kv_cache * kv ,
676+ llama_seq_id seq_id);
669677
670678 // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
671679 // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -674,11 +682,11 @@ extern "C" {
674682 // p0 < 0 : [0, p1]
675683 // p1 < 0 : [p0, inf)
676684 LLAMA_API void llama_kv_cache_seq_add (
677- struct llama_context * ctx ,
678- llama_seq_id seq_id,
679- llama_pos p0,
680- llama_pos p1,
681- llama_pos delta);
685+ struct llama_kv_cache * kv ,
686+ llama_seq_id seq_id,
687+ llama_pos p0,
688+ llama_pos p1,
689+ llama_pos delta);
682690
683691 // Integer division of the positions by factor of `d > 1`
684692 // If the KV cache is RoPEd, the KV data is updated accordingly:
@@ -687,31 +695,28 @@ extern "C" {
687695 // p0 < 0 : [0, p1]
688696 // p1 < 0 : [p0, inf)
689697 LLAMA_API void llama_kv_cache_seq_div (
690- struct llama_context * ctx ,
691- llama_seq_id seq_id,
692- llama_pos p0,
693- llama_pos p1,
694- int d);
698+ struct llama_kv_cache * kv ,
699+ llama_seq_id seq_id,
700+ llama_pos p0,
701+ llama_pos p1,
702+ int d);
695703
696704 // Returns the largest position present in the KV cache for the specified sequence
697705 LLAMA_API llama_pos llama_kv_cache_seq_pos_max (
698- struct llama_context * ctx,
699- llama_seq_id seq_id);
700-
701- // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache
702- // how to avoid this?
706+ struct llama_kv_cache * kv,
707+ llama_seq_id seq_id);
703708
704709 // Defragment the KV cache
705710 // This will be applied:
706711 // - lazily on next llama_decode()
707712 // - explicitly with llama_kv_cache_update()
708- LLAMA_API void llama_kv_cache_defrag (struct llama_context * ctx);
709-
710- // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
711- LLAMA_API void llama_kv_cache_update (struct llama_context * ctx);
713+ LLAMA_API void llama_kv_cache_defrag (struct llama_kv_cache * kv);
712714
713715 // Check if the context supports KV cache shifting
714- LLAMA_API bool llama_kv_cache_can_shift (struct llama_context * ctx);
716+ LLAMA_API bool llama_kv_cache_can_shift (const struct llama_kv_cache * kv);
717+
718+ // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
719+ LLAMA_API void llama_update_kv_cache (struct llama_context * ctx, struct llama_kv_cache * kv);
715720
716721 //
717722 // State / sessions
0 commit comments