@@ -171,7 +171,7 @@ struct llama_context : public llama_graph_i {
     // graph
     //

-    // zero-out inputs and create the ctx_context for the compute graph
+    // zero-out inputs and create the ctx_compute for the compute graph
     virtual ggml_cgraph * graph_init();

     // TODO: add encode/decode graphs
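
The comment on `graph_init()` above summarizes its job: reset the per-graph input state and (re)build `ctx_compute`, the metadata-only ggml context in which the compute graph is created. A minimal sketch of what such an implementation could look like, assuming the standard ggml C API; `buf_compute_meta` and the node limit are assumptions for illustration, not taken from this diff:

```cpp
// Sketch only: the real body lives in llama-context.cpp and may differ.
ggml_cgraph * llama_context::graph_init() {
    // reset whatever per-graph input tensors the context tracks (details omitted)

    // ctx_compute holds tensor/graph metadata only; tensor data is allocated later
    // by the backend scheduler, hence no_alloc = true
    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_compute_meta.size(), // assumed metadata buffer member
        /*.mem_buffer =*/ buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    ctx_compute.reset(ggml_init(params));

    const size_t max_nodes = 8192; // assumed upper bound on graph nodes
    return ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
}
```
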
@@ -187,73 +187,74 @@ struct llama_context : public llama_graph_i {

     ggml_context_ptr ctx_compute;

+public:
     //
-    // graph build API (generic)
+    // graph build
     //

     virtual void build_cb(
             ggml_tensor * cur,
             const char * name,
             const llama_ubatch & ubatch,
-            int il);
+            int il) override;

     // apply control vector for layer il
     virtual ggml_tensor * build_cvec(
             ggml_context * ctx0,
             ggml_tensor * cur,
-            int il);
+            int il) override;

     // do mat_mul, while optionally apply lora
     virtual ggml_tensor * build_lora_mm(
             ggml_context * ctx0,
             ggml_tensor * w,
-            ggml_tensor * cur);
+            ggml_tensor * cur) override;

     // do mat_mul_id, while optionally apply lora
     virtual ggml_tensor * build_lora_mm_id(
             ggml_context * ctx0,
             ggml_tensor * w,   // struct ggml_tensor * as
             ggml_tensor * cur, // struct ggml_tensor * b
-            ggml_tensor * ids);
+            ggml_tensor * ids) override;

-    virtual ggml_tensor * build_rope_factors(int il);
+    virtual ggml_tensor * build_rope_factors(int il) override;

     virtual ggml_tensor * build_rope_shift(
             ggml_context * ctx0,
             ggml_tensor * cur,
             ggml_tensor * shift,
             ggml_tensor * factors,
-            ggml_backend_buffer * bbuf);
+            ggml_backend_buffer * bbuf) override;

     virtual ggml_tensor * build_inp_embd(
             ggml_context * ctx0,
             ggml_tensor * tok_embd,
-            const llama_ubatch & ubatch);
+            const llama_ubatch & ubatch) override;

     virtual ggml_tensor * build_inp_pos(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual ggml_tensor * build_inp_pos_bucket(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual ggml_tensor * build_inp_out_ids(
-            ggml_context * ctx0);
+            ggml_context * ctx0) override;

     virtual ggml_tensor * build_inp_mean(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual ggml_tensor * build_inp_cls(
             ggml_context * ctx0,
-            int32_t n_tokens);
+            int32_t n_tokens) override;

     virtual void build_attn_inp(
             ggml_context * ctx0,
             int32_t n_tokens,
             bool causal,
-            bool swa);
+            bool swa) override;

     virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
@@ -266,7 +267,17 @@ struct llama_context : public llama_graph_i {
             ggml_tensor * kq_b,
             int32_t n_tokens,
             float kq_scale,
-            int il);
+            int il) override;
+
+protected:
+    virtual void build_kv_self_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * gf);
+
+    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
+    virtual void build_kv_self_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * gf);

 public:
     //
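
The hunk above hoists `build_kv_self_shift()` and `build_kv_self_defrag()` into the base class as protected virtuals, so generic code in `llama_context` can request K-shift and defragmentation nodes while each cache-backed subclass decides how to emit them. A skeleton of the override side, with placeholder bodies that only describe the intent (the actual graph construction is not part of this diff):

```cpp
// Placeholder bodies: the comments describe the intent, not the real implementation.
void llama_context_kv_self::build_kv_self_shift(ggml_context * ctx0, ggml_cgraph * gf) {
    // for each layer, apply the accumulated position delta to the cached K tensors
    // (e.g. a RoPE-style shift) and append the resulting nodes to gf
}

void llama_context_kv_self::build_kv_self_defrag(ggml_context * ctx0, ggml_cgraph * gf) {
    // locate holes at the beginning of the KV cache and emit copy nodes that move
    // cells from the end of the cache into them, compacting the used range
}
```
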
@@ -434,6 +445,7 @@ class llama_context_kv_self : public llama_context {

     virtual ggml_cgraph * graph_init() override;

+public:
     //
     // graph build
     //
@@ -463,6 +475,7 @@ class llama_context_kv_self : public llama_context {
             float kq_scale,
             int il) override;

+protected:
     virtual void build_kv_self_shift(
             ggml_context * ctx0,
             ggml_cgraph * gf) override;
@@ -548,6 +561,7 @@ class llama_context_recurrent : public llama_context {

     virtual ggml_cgraph * graph_init() override;

+public:
     //
     // graph build
     //
@@ -600,6 +614,7 @@ class llama_context_recurrent : public llama_context {
             const llama_ubatch & ubatch,
             int il) override;

+protected:
     //
     // state save/load
     //
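
Taken together, the hunks tighten the class hierarchy: the graph-build entry points declared by the `llama_graph_i` interface are now explicitly marked `override` in `llama_context` (so any signature drift becomes a compile error), the build methods sit in a `public:` section, and the KV-cache maintenance hooks are `protected:` virtuals that `llama_context_kv_self` specializes. A condensed view of the resulting shape, with signatures trimmed to one representative method and unrelated members omitted; the reduction to a single pure-virtual method is illustrative, not the actual interface:

```cpp
struct llama_graph_i {
    virtual ~llama_graph_i() = default;

    // one representative graph-build entry point
    virtual ggml_tensor * build_inp_pos(ggml_context * ctx0, int32_t n_tokens) = 0;
};

struct llama_context : public llama_graph_i {
public:
    // generic implementation; override makes signature drift a compile error
    ggml_tensor * build_inp_pos(ggml_context * ctx0, int32_t n_tokens) override;

protected:
    // cache-maintenance hooks specialized by derived contexts
    virtual void build_kv_self_shift (ggml_context * ctx0, ggml_cgraph * gf);
    virtual void build_kv_self_defrag(ggml_context * ctx0, ggml_cgraph * gf);
};

class llama_context_kv_self : public llama_context {
protected:
    void build_kv_self_shift (ggml_context * ctx0, ggml_cgraph * gf) override;
    void build_kv_self_defrag(ggml_context * ctx0, ggml_cgraph * gf) override;
};
```
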