Skip to content

Commit 7956bb4

Browse files
authored
bench : cache the llama_context state at computed depth (#16944)
* bench : cache llama_context state at depth
* cont : handle failures to restore the old state
* cont : print information when the state is being reused
1 parent 9008027 commit 7956bb4

File tree

1 file changed

+40
-7
lines changed

1 file changed

+40
-7
lines changed

tools/llama-bench/llama-bench.cpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
19191919
}
19201920
};
19211921

1922+
struct ctx_state {
1923+
int depth = 0; // in tokens
1924+
1925+
std::vector<uint8_t> buf; // the llama_context state buffer
1926+
};
1927+
19221928
static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
19231929
llama_set_n_threads(ctx, n_threads, n_threads);
19241930

@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
20512057
llama_model * lmodel = nullptr;
20522058
const cmd_params_instance * prev_inst = nullptr;
20532059

2060+
// store the llama_context state at the previous depth that we performed a test
2061+
// ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
2062+
ctx_state cstate;
2063+
20542064
int params_idx = 0;
20552065
auto params_count = params_instances.size();
20562066
for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
21342144
llama_memory_clear(llama_get_memory(ctx), false);
21352145

21362146
if (t.n_depth > 0) {
2137-
if (params.progress) {
2138-
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
2139-
i + 1, params.reps);
2147+
bool is_cached = t.n_depth == cstate.depth;
2148+
2149+
if (is_cached) {
2150+
// if previously we have computed at this depth, just restore the state
2151+
const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
2152+
if (ret == 0) {
2153+
// if the old state is incompatible with the current context - reprocess from scratch
2154+
is_cached = false;
2155+
}
21402156
}
2141-
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
2142-
if (!res) {
2143-
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
2144-
exit(1);
2157+
2158+
if (!is_cached) {
2159+
if (params.progress) {
2160+
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
2161+
i + 1, params.reps);
2162+
}
2163+
bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
2164+
if (!res) {
2165+
fprintf(stderr, "%s: error: failed to run depth\n", __func__);
2166+
exit(1);
2167+
}
2168+
2169+
// store the context state for reuse in later runs
2170+
cstate.depth = t.n_depth;
2171+
cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
2172+
llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
2173+
} else {
2174+
if (params.progress) {
2175+
fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
2176+
i + 1, params.reps);
2177+
}
21452178
}
21462179
}
21472180

0 commit comments

Comments
 (0)