@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
19191919 }
19201920};
19211921
1922+ struct ctx_state {
1923+     int depth = 0; // in tokens
1924+
1925+     std::vector<uint8_t> buf; // the llama_context state buffer
1926+ };
1927+
19221928static bool test_prompt (llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
19231929 llama_set_n_threads (ctx, n_threads, n_threads);
19241930
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
20512057 llama_model * lmodel = nullptr ;
20522058 const cmd_params_instance * prev_inst = nullptr ;
20532059
2060+ // store the llama_context state at the previous depth that we performed a test
2061+ // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
2062+ ctx_state cstate;
2063+
20542064 int params_idx = 0 ;
20552065 auto params_count = params_instances.size ();
20562066 for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
21342144 llama_memory_clear (llama_get_memory (ctx), false );
21352145
21362146 if (t.n_depth > 0 ) {
2137-             if (params.progress) {
2138-                 fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
2139-                         i + 1, params.reps);
2147+             bool is_cached = t.n_depth == cstate.depth;
2148+
2149+             if (is_cached) {
2150+                 // if previously we have computed at this depth, just restore the state
2151+                 const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
2152+                 if (ret == 0) {
2153+                     // if the old state is incompatible with the current context - reprocess from scratch
2154+                     is_cached = false;
2155+                 }
21402156 }
2141-             bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
2142-             if (!res) {
2143-                 fprintf(stderr, "%s: error: failed to run depth\n", __func__);
2144-                 exit(1);
2157+
2158+             if (!is_cached) {
2159+                 if (params.progress) {
2160+                     fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
2161+                             i + 1, params.reps);
2162+                 }
2163+                 bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
2164+                 if (!res) {
2165+                     fprintf(stderr, "%s: error: failed to run depth\n", __func__);
2166+                     exit(1);
2167+                 }
2168+
2169+                 // store the context state for reuse in later runs
2170+                 cstate.depth = t.n_depth;
2171+                 cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
2172+                 llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
2173+             } else {
2174+                 if (params.progress) {
2175+                     fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
2176+                             i + 1, params.reps);
2177+                 }
21452178 }
21462179 }
21472180
0 commit comments