@@ -140,6 +140,7 @@ struct llama_context : public llama_graph_i {
140140
141141 virtual void input_set (const llama_ubatch & ubatch);
142142
143+ private:
143144 struct {
144145 // base input tensors
145146 ggml_tensor * tokens; // I32 [n_batch]
@@ -155,6 +156,7 @@ struct llama_context : public llama_graph_i {
155156 ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch]
156157 } inp;
157158
159+ protected:
158160 //
159161 // output
160162 //
@@ -192,71 +194,71 @@ struct llama_context : public llama_graph_i {
192194 // graph build
193195 //
194196
195- virtual void build_cb (
197+ void build_cb (
196198 ggml_tensor * cur,
197199 const char * name,
198200 const llama_ubatch & ubatch,
199201 int il) override ;
200202
201203 // apply control vector for layer il
202- virtual ggml_tensor * build_cvec (
204+ ggml_tensor * build_cvec (
203205 ggml_context * ctx0,
204206 ggml_tensor * cur,
205207 int il) override ;
206208
207209 // do mat_mul, while optionally apply lora
208- virtual ggml_tensor * build_lora_mm (
210+ ggml_tensor * build_lora_mm (
209211 ggml_context * ctx0,
210212 ggml_tensor * w,
211213 ggml_tensor * cur) override ;
212214
213215 // do mat_mul_id, while optionally apply lora
214- virtual ggml_tensor * build_lora_mm_id (
216+ ggml_tensor * build_lora_mm_id (
215217 ggml_context * ctx0,
216218 ggml_tensor * w, // struct ggml_tensor * as
217219 ggml_tensor * cur, // struct ggml_tensor * b
218220 ggml_tensor * ids) override ;
219221
220- virtual ggml_tensor * build_rope_factors (int il) override ;
222+ ggml_tensor * build_rope_factors (int il) override ;
221223
222- virtual ggml_tensor * build_rope_shift (
224+ ggml_tensor * build_rope_shift (
223225 ggml_context * ctx0,
224226 ggml_tensor * cur,
225227 ggml_tensor * shift,
226228 ggml_tensor * factors,
227229 ggml_backend_buffer * bbuf) override ;
228230
229- virtual ggml_tensor * build_inp_embd (
231+ ggml_tensor * build_inp_embd (
230232 ggml_context * ctx0,
231233 ggml_tensor * tok_embd,
232234 const llama_ubatch & ubatch) override ;
233235
234- virtual ggml_tensor * build_inp_pos (
236+ ggml_tensor * build_inp_pos (
235237 ggml_context * ctx0,
236238 int32_t n_tokens) override ;
237239
238- virtual ggml_tensor * build_inp_pos_bucket (
240+ ggml_tensor * build_inp_pos_bucket (
239241 ggml_context * ctx0,
240242 int32_t n_tokens) override ;
241243
242- virtual ggml_tensor * build_inp_out_ids (
244+ ggml_tensor * build_inp_out_ids (
243245 ggml_context * ctx0) override ;
244246
245- virtual ggml_tensor * build_inp_mean (
247+ ggml_tensor * build_inp_mean (
246248 ggml_context * ctx0,
247249 int32_t n_tokens) override ;
248250
249- virtual ggml_tensor * build_inp_cls (
251+ ggml_tensor * build_inp_cls (
250252 ggml_context * ctx0,
251253 int32_t n_tokens) override ;
252254
253- virtual void build_attn_inp (
255+ void build_attn_inp (
254256 ggml_context * ctx0,
255257 int32_t n_tokens,
256258 bool causal,
257259 bool swa) override ;
258260
259- virtual ggml_tensor * build_attn (
261+ ggml_tensor * build_attn (
260262 ggml_context * ctx0,
261263 ggml_cgraph * gf,
262264 ggml_tensor * wo,
@@ -270,6 +272,9 @@ struct llama_context : public llama_graph_i {
270272 int il) override ;
271273
272274protected:
275+ virtual ggml_tensor * build_inp_self_k_shift (
276+ ggml_context * ctx0);
277+
273278 virtual void build_kv_self_shift (
274279 ggml_context * ctx0,
275280 ggml_cgraph * gf);
@@ -288,6 +293,7 @@ struct llama_context : public llama_graph_i {
288293 virtual void perf_reset ();
289294
290295protected:
296+ // TODO: become private
291297 mutable int64_t t_start_us = 0 ;
292298 mutable int64_t t_load_us = 0 ;
293299 mutable int64_t t_p_eval_us = 0 ;
@@ -346,6 +352,7 @@ struct llama_context : public llama_graph_i {
346352 //
347353 // members
348354 //
355+ // TODO: become private / move to llama_graph_i
349356
350357 const llama_model & model;
351358
@@ -412,24 +419,25 @@ class llama_context_kv_self : public llama_context {
412419 virtual ~llama_context_kv_self ();
413420
414421protected:
415- virtual void reserve () override ;
422+ void reserve () override ;
416423
417424public:
418- virtual llama_kv_cache * get_kv_self () override ;
419- virtual const llama_kv_cache * get_kv_self () const override ;
425+ llama_kv_cache * get_kv_self () override ;
426+ const llama_kv_cache * get_kv_self () const override ;
420427
421- virtual void kv_self_update () override ;
428+ void kv_self_update () override ;
422429
423- virtual int encode (llama_batch & inp_batch) override ;
424- virtual int decode (llama_batch & inp_batch) override ;
430+ int encode (llama_batch & inp_batch) override ;
431+ int decode (llama_batch & inp_batch) override ;
425432
426433protected:
427434 //
428435 // input
429436 //
430437
431- virtual void input_set (const llama_ubatch & ubatch) override ;
438+ void input_set (const llama_ubatch & ubatch) override ;
432439
440+ private:
433441 struct {
434442 ggml_tensor * self_pos_bucket; // I32 [n_kv, n_batch]
435443 ggml_tensor * self_kq_mask; // F32 [n_kv, n_batch]
@@ -443,26 +451,24 @@ class llama_context_kv_self : public llama_context {
443451 // graph
444452 //
445453
446- virtual ggml_cgraph * graph_init () override ;
454+ ggml_cgraph * graph_init () override ;
447455
448456public:
449457 //
450458 // graph build
451459 //
452460
453- virtual ggml_tensor * build_inp_self_k_shift (ggml_context * ctx0) override ;
454-
455- virtual ggml_tensor * build_inp_pos_bucket (
461+ ggml_tensor * build_inp_pos_bucket (
456462 ggml_context * ctx0,
457463 int32_t n_tokens) override ;
458464
459- virtual void build_attn_inp (
465+ void build_attn_inp (
460466 ggml_context * ctx0,
461467 int32_t n_tokens,
462468 bool causal,
463469 bool swa) override ;
464470
465- virtual ggml_tensor * build_attn (
471+ ggml_tensor * build_attn (
466472 ggml_context * ctx0,
467473 ggml_cgraph * gf,
468474 ggml_tensor * wo,
@@ -476,16 +482,22 @@ class llama_context_kv_self : public llama_context {
476482 int il) override ;
477483
478484protected:
479- virtual void build_kv_self_shift (
485+ ggml_tensor * build_inp_self_k_shift (ggml_context * ctx0) override ;
486+
487+ void build_kv_self_shift (
480488 ggml_context * ctx0,
481489 ggml_cgraph * gf) override ;
482490
483491 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
484- virtual void build_kv_self_defrag (
492+ void build_kv_self_defrag (
485493 ggml_context * ctx0,
486494 ggml_cgraph * gf) override ;
487495
496+ // =======================================================
488497 // === encoder-decoder ===
498+ //
499+ // TODO: this is temporary here, it will be moved
500+ //
489501
490502 // whether we are computing encoder output or decoder output
491503 bool is_encoding = false ;
@@ -497,23 +509,25 @@ class llama_context_kv_self : public llama_context {
497509 struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
498510 struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]
499511
500- virtual ggml_tensor * build_inp_embd_enc (
512+ ggml_tensor * build_inp_embd_enc (
501513 ggml_context * ctx0) override ;
502514
503- virtual ggml_tensor * build_inp_kq_mask_cross (
515+ ggml_tensor * build_inp_kq_mask_cross (
504516 ggml_context * ctx0,
505517 int32_t n_tokens) override ;
518+ // ======================================================
506519
507520 //
508521 // state save/load
509522 //
510523
511- virtual size_t state_get_data (llama_io_write_i & io) override ;
512- virtual size_t state_set_data (llama_io_read_i & io) override ;
524+ size_t state_get_data (llama_io_write_i & io) override ;
525+ size_t state_set_data (llama_io_read_i & io) override ;
513526
514- virtual size_t state_seq_get_data (llama_io_write_i & io, llama_seq_id seq_id) override ;
515- virtual size_t state_seq_set_data (llama_io_read_i & io, llama_seq_id seq_id) override ;
527+ size_t state_seq_get_data (llama_io_write_i & io, llama_seq_id seq_id) override ;
528+ size_t state_seq_set_data (llama_io_read_i & io, llama_seq_id seq_id) override ;
516529
530+ private:
517531 //
518532 // members
519533 //
@@ -532,24 +546,25 @@ class llama_context_recurrent : public llama_context {
532546 virtual ~llama_context_recurrent ();
533547
534548protected:
535- virtual void reserve () override ;
549+ void reserve () override ;
536550
537551public:
538- virtual llama_kv_cache * get_kv_self () override ;
539- virtual const llama_kv_cache * get_kv_self () const override ;
552+ llama_kv_cache * get_kv_self () override ;
553+ const llama_kv_cache * get_kv_self () const override ;
540554
541- virtual void kv_self_update () override ;
555+ void kv_self_update () override ;
542556
543- virtual int encode (llama_batch & inp_batch) override ;
544- virtual int decode (llama_batch & inp_batch) override ;
557+ int encode (llama_batch & inp_batch) override ;
558+ int decode (llama_batch & inp_batch) override ;
545559
546560protected:
547561 //
548562 // input
549563 //
550564
551- virtual void input_set (const llama_ubatch & ubatch) override ;
565+ void input_set (const llama_ubatch & ubatch) override ;
552566
567+ private:
553568 struct {
554569 ggml_tensor * s_copy; // I32 [kv_size]
555570 ggml_tensor * s_mask; // F32 [1, n_kv]
@@ -559,20 +574,20 @@ class llama_context_recurrent : public llama_context {
559574 // graph
560575 //
561576
562- virtual ggml_cgraph * graph_init () override ;
577+ ggml_cgraph * graph_init () override ;
563578
564579public:
565580 //
566581 // graph build
567582 //
568583
569- virtual ggml_tensor * build_inp_s_copy (
584+ ggml_tensor * build_inp_s_copy (
570585 ggml_context * ctx0) override ;
571586
572- virtual ggml_tensor * build_inp_s_mask (
587+ ggml_tensor * build_inp_s_mask (
573588 ggml_context * ctx0) override ;
574589
575- virtual ggml_tensor * build_copy_mask_state (
590+ ggml_tensor * build_copy_mask_state (
576591 ggml_context * ctx0,
577592 ggml_cgraph * gf,
578593 ggml_tensor * s,
@@ -581,7 +596,7 @@ class llama_context_recurrent : public llama_context {
581596 int32_t n_state,
582597 int32_t n_seqs) override ;
583598
584- virtual ggml_tensor * build_mamba_layer (
599+ ggml_tensor * build_mamba_layer (
585600 ggml_context * ctx0,
586601 ggml_cgraph * gf,
587602 ggml_tensor * cur,
@@ -590,21 +605,21 @@ class llama_context_recurrent : public llama_context {
590605 const llama_ubatch & ubatch,
591606 int il) override ;
592607
593- virtual ggml_tensor * build_rwkv_token_shift_load (
608+ ggml_tensor * build_rwkv_token_shift_load (
594609 ggml_context * ctx0,
595610 ggml_cgraph * gf,
596611 ggml_tensor * state_copy,
597612 ggml_tensor * state_mask,
598613 const llama_ubatch & ubatch,
599614 int il) override ;
600615
601- virtual ggml_tensor * build_rwkv_token_shift_store (
616+ ggml_tensor * build_rwkv_token_shift_store (
602617 ggml_context * ctx0,
603618 ggml_tensor * token_shift,
604619 const llama_ubatch & ubatch,
605620 int il) override ;
606621
607- virtual ggml_tensor * build_rwkv6_time_mix (
622+ ggml_tensor * build_rwkv6_time_mix (
608623 ggml_context * ctx0,
609624 ggml_cgraph * gf,
610625 ggml_tensor * cur,
@@ -619,12 +634,13 @@ class llama_context_recurrent : public llama_context {
619634 // state save/load
620635 //
621636
622- virtual size_t state_get_data (llama_io_write_i & io) override ;
623- virtual size_t state_set_data (llama_io_read_i & io) override ;
637+ size_t state_get_data (llama_io_write_i & io) override ;
638+ size_t state_set_data (llama_io_read_i & io) override ;
624639
625- virtual size_t state_seq_get_data (llama_io_write_i & io, llama_seq_id seq_id) override ;
626- virtual size_t state_seq_set_data (llama_io_read_i & io, llama_seq_id seq_id) override ;
640+ size_t state_seq_get_data (llama_io_write_i & io, llama_seq_id seq_id) override ;
641+ size_t state_seq_set_data (llama_io_read_i & io, llama_seq_id seq_id) override ;
627642
643+ private:
628644 //
629645 // members
630646 //
@@ -646,7 +662,7 @@ class llama_context_enc_dec : public llama_context {
646662
647663 virtual ~llama_context_enc_dec ();
648664
649- protected:
665+ private:
650666 llama_context_kv_self ctx_dec;
651667};
652668
0 commit comments