server : handle failures to restore host cache (#17078)
author    Georgi Gerganov <redacted>
          Sun, 9 Nov 2025 12:27:05 +0000 (14:27 +0200)
committer GitHub <redacted>
          Sun, 9 Nov 2025 12:27:05 +0000 (14:27 +0200)
* server : handle failures to restore host cache

* server : add tests for the prompt cache

tools/server/server.cpp
tools/server/tests/unit/test_completion.py

diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 9d91e32d1fbfbc3e4b901463fa9a5fbcc88ff2c9..6bd4be3cc17c4699e0d4d4808850c9ef1d4965b8 100644
@@ -1690,6 +1690,9 @@ struct server_slot {
         bool res = prompt_cache.load(prompt, tokens, ctx, id);
         if (!res) {
             SLT_WRN(*this, "%s", "failed to load prompt from cache\n");
+
+            llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+            prompt.tokens.clear();
         }
     }
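
For context: in llama.h, a negative p0/p1 passed to llama_memory_seq_rm means "from the start" / "to the end", so the (id, -1, -1) call above purges whatever the failed restore may have partially written into sequence id. Clearing prompt.tokens then keeps the slot's bookkeeping consistent with the now-empty KV state, and the next request simply reprocesses its prompt from scratch. A minimal Python sketch of this recover-by-reset pattern (all names here are hypothetical stand-ins, not the server's actual API):

    # Hypothetical sketch: if restoring a cached prompt fails partway,
    # discard all per-sequence state rather than keep a half-restored
    # KV cache that no longer matches the tracked tokens.
    class Slot:
        def __init__(self, seq_id: int, kv: dict):
            self.seq_id = seq_id
            self.kv = kv        # stand-in for the llama.cpp memory module
            self.tokens = []    # stand-in for prompt.tokens

        def restore(self, host_cache: dict) -> bool:
            entry = host_cache.get(self.seq_id)  # stand-in for prompt_cache.load(...)
            if entry is None:
                # mirrors llama_memory_seq_rm(mem, id, -1, -1) followed by
                # prompt.tokens.clear(): wipe the sequence, forget the prompt
                self.kv.pop(self.seq_id, None)
                self.tokens.clear()
                return False
            self.kv[self.seq_id] = list(entry)
            self.tokens = list(entry)
            return True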
 
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index 3c0ce98973f4b8aa84506c7fad74534ba3aab387..ef1757db21f7f472ff99c674536f40089443d043 100644
@@ -1,6 +1,8 @@
 import pytest
 import requests
 import time
+import random
+
 from openai import OpenAI
 from utils import *
 
@@ -564,3 +566,43 @@ def test_cancel_request():
     time.sleep(1) # wait for HTTP_POLLING_SECONDS
     res = server.make_request("GET", "/slots")
     assert res.body[0]["is_processing"] == False
+
+
+# this test exercises the host-memory prompt cache
+# ref: https://github.com/ggml-org/llama.cpp/pull/16391
+# ref: https://github.com/ggml-org/llama.cpp/pull/17078
+def test_completion_prompt_cache():
+    global server
+    server.n_slots = 2
+    server.kv_unified = True
+    server.start()
+
+    for _ in range(16):
+        # generate alternating random prompts with variable lengths in order to get them in and out of the cache
+        r = random.randint(0, 4)
+        prompt = (" Hello " + str(r)) * (40 + r)
+        n_prompt = (40 + r)*5 + 2
+        n_predict = random.randint(1, 8)
+
+        res = server.make_request(
+            "POST",
+            "/completion",
+            data={
+                "prompt": prompt,
+                "n_predict": n_predict,
+            },
+        )
+
+        assert res.status_code == 200
+        assert "content" in res.body
+        content = res.body["content"]
+        assert isinstance(content, str)
+        assert len(content) > 0
+
+        assert type(res.body["has_new_line"]) == bool
+        assert "timings" in res.body
+        timings = res.body["timings"]
+
+        assert "prompt_n" in timings and timings["prompt_n"] + timings["cache_n"] == n_prompt
+        assert "predicted_n" in timings and timings["predicted_n"] == n_predict
+        assert "tokens" in res.body and isinstance(res.body["tokens"], list)