]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix cache_tokens bug with no cache_prompt (#13533)
authorXuan-Son Nguyen <redacted>
Wed, 14 May 2025 11:35:07 +0000 (13:35 +0200)
committerGitHub <redacted>
Wed, 14 May 2025 11:35:07 +0000 (13:35 +0200)
tools/server/server.cpp
tools/server/tests/unit/test_completion.py
tools/server/utils.hpp

index 7169ffdceebf9a43accd37468cd76b80e28e10fb..a9b99d437e2fd237912929b2399400052197e39a 100644 (file)
@@ -2951,7 +2951,8 @@ struct server_context {
                 llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
                 llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
 
-                if (slot.params.cache_prompt) {
+                // add generated tokens to cache
+                {
                     llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
                     for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
                         new_tokens[i - n_discard] = new_tokens[i];
@@ -2996,10 +2997,7 @@ struct server_context {
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
-
-            if (slot.params.cache_prompt) {
-                slot.cache_tokens.push_back(slot.sampled);
-            }
+            slot.cache_tokens.push_back(slot.sampled);
 
             SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
                     slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3171,6 +3169,11 @@ struct server_context {
 
                                     SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
                                 }
+                            } else {
+                                // if we don't cache the prompt, we have to remove the entire KV cache
+                                llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
+                                slot.n_past = 0;
+                                slot.cache_tokens.clear();
                             }
                         }
 
@@ -3204,7 +3207,7 @@ struct server_context {
                     SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
                     // remove the non-common part from the cache
-                    slot.cache_tokens.resize(slot.n_past);
+                    slot.cache_tokens.keep_first(slot.n_past);
 
                     // check if we should process the image
                     if (slot.n_past < slot.n_prompt_tokens
@@ -3221,7 +3224,8 @@ struct server_context {
                             continue;
                         }
 
-                        if (slot.params.cache_prompt) {
+                        // add the image chunk to cache
+                        {
                             const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                             slot.cache_tokens.push_back(chunk.get()); // copy
                         }
@@ -3242,9 +3246,7 @@ struct server_context {
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
                         common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
-                        if (slot.params.cache_prompt) {
-                            slot.cache_tokens.push_back(cur_tok);
-                        }
+                        slot.cache_tokens.push_back(cur_tok);
 
                         slot.n_prompt_tokens_processed++;
                         slot.n_past++;
index 0ed5b99bef4e4a3c7abd2fee026960c859806802..4099c4e25cd6e8774d267495d87b9c508a1a489e 100644 (file)
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
     assert res_cache.body["content"] == res_no_cache.body["content"]
 
 
+def test_nocache_long_input_prompt():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is"*32,
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res.status_code == 200
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
index b8d140e3f051cc79b20fe2abc6fffb3b6c25b62f..45193c17cfd98dc5a3ddf03d1e9697b2bcd331e6 100644 (file)
@@ -1153,7 +1153,7 @@ public:
         tokens.clear();
     }
 
-    void resize(size_t n) {
+    void keep_first(size_t n) {
         GGML_ASSERT(n <= tokens.size());
         if (has_mtmd) {
             // we throw an error if we try to remove a token in the middle of an image