]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: fix regression on streamed non-chat completion w/ stops (#13785)
authorOlivier Chafik <redacted>
Mon, 26 May 2025 13:16:37 +0000 (06:16 -0700)
committerGitHub <redacted>
Mon, 26 May 2025 13:16:37 +0000 (14:16 +0100)
* more forgiving message diffs: partial stop words aren't erased, full stops are

* Add (slow) server test for completion + stream + stop

common/chat.cpp
tools/server/tests/unit/test_completion.py

index adfe51db5a7704faa577f35a97ccbcfc59e48198..c2379f669dc89b13470cf8514f46bab36aae1a1b 100644 (file)
@@ -31,6 +31,11 @@ static std::string string_diff(const std::string & last, const std::string & cur
         return current;
     }
     if (!string_starts_with(current, last)) {
+        if (string_starts_with(last, current)) {
+            // This happens if the last generation ended on a partial stop word (not erased),
+            // and the current ended on a stop word (erased).
+            return "";
+        }
         throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'");
     }
     return current.substr(last.size());
index 4099c4e25cd6e8774d267495d87b9c508a1a489e..f6909e9ae788438ca3125da421642c58e1f83c46 100644 (file)
@@ -121,6 +121,30 @@ def test_completion_stream_with_openai_library():
     assert match_regex("(going|bed)+", output_text)
 
 
+# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
+@pytest.mark.slow
+def test_completion_stream_with_openai_library_stops():
+    global server
+    server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M"
+    server.model_hf_file = None
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.completions.create(
+        model="davinci-002",
+        prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n",
+        stop=["User:\n", "Assistant:\n"],
+        max_tokens=200,
+        stream=True,
+    )
+    output_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            assert choice.text is not None
+            output_text += choice.text
+    assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'
+
+
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server