@@ -32,6 +32,38 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
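+// run a compute graph on the backend scheduler, using the batch thread settings
+// when `batched` is true and the regular ones otherwise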
+enum ggml_status llama_context::compute_graph(
+            ggml_cgraph * graph,
+                   bool   batched) {
+    int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
+    ggml_threadpool_t tp = batched ? threadpool_batch : threadpool;
+
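+    // the CPU backend's threadpool setter is looked up by name through the
+    // backend registry rather than called directly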
+    if (backend_cpu != nullptr) {
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(backend_cpu, tp);
+    }
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(sched.get(), graph);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
+
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+
+    return status;
+}
+
+
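+// largest position currently stored in the KV cache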
+llama_pos llama_context::pos_max() const {
+    return kv_self.pos_max();
+}
+
 // TODO: improve
 void llama_context::reset() {
     inp_tokens = nullptr;
@@ -540,6 +572,93 @@ ggml_tensor * llama_context::build_lora_mm_id(
     return res;
 }
 
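+// apply any pending K-shift or defragmentation to the KV cache
+// returns true if a compute graph was run, i.e. the scheduler needs to be reserved again by the caller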
+bool llama_context::kv_self_update() {
+    bool need_reserve = false;
+
+    auto & kv = kv_self;
+
+    if (kv.has_shift) {
+        if (!kv.can_shift) {
+            GGML_ABORT("The current context does not support K-shift");
+        }
+
+        // apply K-shift if needed
+        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            prepare_k_shift();
+
+            ggml_backend_sched_reset(sched.get());
+
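+            // build the K-shift graph in the temporary compute context
+            // (no_alloc: tensor data is allocated later by the scheduler)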
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ buf_compute_meta.size(),
+                /*.mem_buffer =*/ buf_compute_meta.data(),
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx0 = ggml_init(params);
+
+            reset();
+
+            ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+            build_k_shift(ctx0, gf);
+
+            ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+            set_inputs({});
+
+            compute_graph(gf, false);
+
+            ggml_free(ctx0);
+
+            need_reserve = true;
+        }
+
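+        // clear the pending shift and the per-cell deltas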
+        {
+            kv.has_shift = false;
+
+            for (uint32_t i = 0; i < kv.size; ++i) {
+                kv.cells[i].delta = 0;
+            }
+        }
+    }
+
+    // defragment the KV cache if needed
+    if (kv.do_defrag) {
+        prepare_defrag();
+
+        ggml_backend_sched_reset(sched.get());
+
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute_meta.size(),
+            /*.mem_buffer =*/ buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+
+        ggml_context * ctx0 = ggml_init(params);
+
+        reset();
+
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        build_defrag(ctx0, gf);
+
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+
+        // no input
+        // set_inputs({});
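+        // defrag only rearranges data that is already in the cache, so there are no graph inputs to set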
+
+        compute_graph(gf, false);
+
+        ggml_free(ctx0);
+
+        need_reserve = true;
+
+        kv.do_defrag = false;
+    }
+
+    return need_reserve;
+}
+
 void llama_context::build_attn_inp(
         ggml_context * ctx0,
         int32_t n_tokens,