]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
bench : cache the llama_context state at computed depth (#16944)
authorGeorgi Gerganov <redacted>
Fri, 7 Nov 2025 19:23:11 +0000 (21:23 +0200)
committerGitHub <redacted>
Fri, 7 Nov 2025 19:23:11 +0000 (21:23 +0200)
* bench : cache llama_context state at depth

* cont : handle failures to restore the old state

* cont : print information when the state is being reused

tools/llama-bench/llama-bench.cpp

index 0de07b98112681f571ca4056d90da818a69b8026..852a512451d64ca995f171795d6a085d9084b6c8 100644 (file)
@@ -1919,6 +1919,12 @@ struct sql_printer : public printer {
     }
 };
 
+struct ctx_state {
+    int depth = 0; // in tokens
+
+    std::vector<uint8_t> buf; // the llama_context state buffer
+};
+
 static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
@@ -2051,6 +2057,10 @@ int main(int argc, char ** argv) {
     llama_model *               lmodel    = nullptr;
     const cmd_params_instance * prev_inst = nullptr;
 
+    // store the llama_context state at the previous depth that we performed a test
+    // ref: https://github.com/ggml-org/llama.cpp/pull/16944#issuecomment-3478151721
+    ctx_state cstate;
+
     int  params_idx   = 0;
     auto params_count = params_instances.size();
     for (const auto & inst : params_instances) {
@@ -2134,14 +2144,37 @@ int main(int argc, char ** argv) {
             llama_memory_clear(llama_get_memory(ctx), false);
 
             if (t.n_depth > 0) {
-                if (params.progress) {
-                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
-                            i + 1, params.reps);
+                bool is_cached = t.n_depth == cstate.depth;
+
+                if (is_cached) {
+                    // if previously we have computed at this depth, just restore the state
+                    const size_t ret = llama_state_seq_set_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                    if (ret == 0) {
+                        // if the old state is incompatible with the current context - reprocess from scratch
+                        is_cached = false;
+                    }
                 }
-                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
-                if (!res) {
-                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
-                    exit(1);
+
+                if (!is_cached) {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
+                    bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                    if (!res) {
+                        fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                        exit(1);
+                    }
+
+                    // store the context state for reuse in later runs
+                    cstate.depth = t.n_depth;
+                    cstate.buf.resize(llama_state_seq_get_size(ctx, 0));
+                    llama_state_seq_get_data(ctx, cstate.buf.data(), cstate.buf.size(), 0);
+                } else {
+                    if (params.progress) {
+                        fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d (cached)\n", params_idx, params_count,
+                                i + 1, params.reps);
+                    }
                 }
             }