]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix speculative decoding with context shift (#10641)
authorGeorgi Gerganov <redacted>
Wed, 4 Dec 2024 20:38:20 +0000 (22:38 +0200)
committerGitHub <redacted>
Wed, 4 Dec 2024 20:38:20 +0000 (22:38 +0200)
* server : fix speculative decoding with context shift

ggml-ci

* server : take into account speculative limits

ggml-ci

* server : add tests

examples/server/server.cpp
examples/server/tests/unit/test_speculative.py

index 9bca3f30e7574b1d0e717362cda124d9b5a9982e..31dfd624080470971231e79cc17d7a51ca446015 100644 (file)
@@ -921,6 +921,8 @@ struct server_context {
         slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
 
         slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+        slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
+        slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
 
         if (slot.params.sampling.dry_base < 1.0f) {
            slot.params.sampling.dry_base = defaults.sampling.dry_base;
@@ -2322,10 +2324,29 @@ struct server_context {
                     continue;
                 }
 
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                //       also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max < slot.params.speculative.n_min) {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+
+                    continue;
+                }
+
                 llama_token id = slot.sampled;
 
                 struct common_speculative_params params_spec;
-                params_spec.n_draft   = slot.params.speculative.n_max;
+                params_spec.n_draft   = n_draft_max;
                 params_spec.n_reuse   = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
                 params_spec.p_min     = slot.params.speculative.p_min;
 
@@ -2333,6 +2354,8 @@ struct server_context {
 
                 // ignore small drafts
                 if (slot.params.speculative.n_min > (int) draft.size()) {
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+
                     continue;
                 }
 
@@ -2344,6 +2367,8 @@ struct server_context {
                     common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
                 }
 
+                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
                 llama_decode(ctx, slot.batch_spec);
 
                 // the accepted tokens from the speculation
@@ -2372,7 +2397,7 @@ struct server_context {
                     }
                 }
 
-                SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
+                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
             }
         }
 
index 982d6abb45f5f735a37fb6e0be0d06185c31b39a..3bb5733cbdf48f82912ad5f7386d5dac411bddda 100644 (file)
@@ -82,6 +82,37 @@ def test_different_draft_min_draft_max():
         last_content = res.body["content"]
 
 
+def test_slot_ctx_not_exceeded():
+    global server
+    server.n_ctx = 64
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "Hello " * 56,
+        "temperature": 0.0,
+        "top_k": 1,
+        "speculative.p_min": 0.0,
+    })
+    assert res.status_code == 200
+    assert len(res.body["content"]) > 0
+
+
+def test_with_ctx_shift():
+    global server
+    server.n_ctx = 64
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "Hello " * 56,
+        "temperature": 0.0,
+        "top_k": 1,
+        "n_predict": 64,
+        "speculative.p_min": 0.0,
+    })
+    assert res.status_code == 200
+    assert len(res.body["content"]) > 0
+    assert res.body["tokens_predicted"] == 64
+    assert res.body["truncated"] == True
+
+
 @pytest.mark.parametrize("n_slots,n_requests", [
     (1, 2),
     (2, 2),