@@ -4,15 +4,16 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-graph.h"
-#include "llama-model.h"
-#include "llama-kv-cache.h"
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
 
 #include <map>
 #include <vector>
 
+struct llama_model;
+struct llama_kv_cache;
+
 class llama_io_read_i;
 class llama_io_write_i;
 
@@ -244,28 +245,29 @@ class llama_context_base : public llama_context {
 
     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    virtual int32_t output_reserve(int32_t n_outputs);
+    int32_t output_reserve(int32_t n_outputs);
 
     // make the outputs have the same order they had in the user-provided batch
     // TODO: maybe remove this
-    virtual void output_reorder();
+    void output_reorder();
 
     //
     // graph
     //
 
-    virtual int32_t graph_max_nodes() const;
+    int32_t graph_max_nodes() const;
 
     // zero-out inputs and create the ctx_compute for the compute graph
-    virtual ggml_cgraph * graph_init();
+    ggml_cgraph * graph_init();
 
+    // override this method in order to pass custom set of parameters to the llm_graph_context
     virtual llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
             const llama_ubatch & ubatch);
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
-    virtual enum ggml_status graph_compute(
+    enum ggml_status graph_compute(
             ggml_cgraph * gf,
             bool batched);
 
@@ -330,6 +332,8 @@ class llama_context_base : public llama_context {
             size_t n_token_count) override;
 
 protected:
+    // override these to store all relevant state for the specific context
+    // TODO: read/write adapters
     virtual size_t state_write_data(llama_io_write_i & io);
     virtual size_t state_read_data (llama_io_read_i & io);
 
@@ -345,10 +349,10 @@ class llama_context_base : public llama_context {
 
     const llm_graph_type gtype;
 
-    llama_cparams      cparams;
-    llama_adapter_cvec cvec;
-    llama_loras        loras;
-    llama_sbatch       sbatch;
+    llama_cparams       cparams;
+    llama_adapter_cvec  cvec;
+    llama_adapter_loras loras;
+    llama_sbatch        sbatch;
 
     ggml_backend_sched_ptr sched;
 
@@ -431,8 +435,6 @@ class llama_context_kv_self : public llama_context_base {
     // graph
     //
 
-    ggml_cgraph * graph_init() override;
-
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -482,8 +484,6 @@ class llama_context_recurrent : public llama_context_base {
     // graph
     //
 
-    ggml_cgraph * graph_init() override;
-
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -532,8 +532,6 @@ class llama_context_dec : public llama_context_kv_self {
     // graph
     //
 
-    ggml_cgraph * graph_init() override;
-
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
             ggml_cgraph * gf,
@@ -677,7 +675,3 @@ class llama_context_enc_dec : public llama_context {
 
     llama_cross cross;
 };
-
-// For internal test use
-// TODO: remove
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
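The methods that stay `virtual` after this change mark the intended extension points: a derived context overrides `graph_build` to pass its own parameters to the `llm_graph_context`, and `state_write_data`/`state_read_data` to persist whatever extra state it carries. Below is a minimal sketch of such a subclass, assuming the declarations in this header and the `write`/`read_to` primitives of the IO interfaces; the class name `llama_context_custom` and its `n_extra` member are hypothetical, not part of this PR:

// hypothetical subclass illustrating the extension points left virtual above;
// llama_context_custom and n_extra are made up for illustration
class llama_context_custom : public llama_context_base {
protected:
    llm_graph_result_ptr graph_build(
            ggml_context * ctx,
            ggml_cgraph * gf,
            const llama_ubatch & ubatch) override {
        // a real implementation would construct the llm_graph_context with
        // custom parameters here; this sketch just defers to the base class
        return llama_context_base::graph_build(ctx, gf, ubatch);
    }

    size_t state_write_data(llama_io_write_i & io) override {
        size_t n_bytes = llama_context_base::state_write_data(io); // base state first
        io.write(&n_extra, sizeof(n_extra));                       // then the extra state
        return n_bytes + sizeof(n_extra);
    }

    size_t state_read_data(llama_io_read_i & io) override {
        size_t n_bytes = llama_context_base::state_read_data(io);
        io.read_to(&n_extra, sizeof(n_extra));
        return n_bytes + sizeof(n_extra);
    }

private:
    int32_t n_extra = 0; // hypothetical per-context state
};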