git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix first message identification (#13634)
author Dorin-Andrei Geman <redacted>
Wed, 21 May 2025 13:07:57 +0000 (16:07 +0300)
committer GitHub <redacted>
Wed, 21 May 2025 13:07:57 +0000 (15:07 +0200)
* server : fix first message identification

When using the OpenAI SDK (https://github.com/openai/openai-node/blob/master/src/lib/ChatCompletionStream.ts#L623-L626) we noticed that the expected assistant role is missing in the first streaming message. Fix this by correctly checking for the first message.

Co-authored-by: Piotr Stankiewicz <redacted>
Signed-off-by: Dorin Geman <redacted>
* server : Fix checks for first role message for stream=True

Co-authored-by: Piotr Stankiewicz <redacted>
Signed-off-by: Dorin Geman <redacted>
---------

Signed-off-by: Dorin Geman <redacted>
Co-authored-by: Piotr Stankiewicz <redacted>
tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py

index 3b1305e1a8d7f8b212afe08c7c7926c6ace113da..d48cf46e48d01af2d68a1d69abe68b058fc91a6f 100644 (file)
@@ -951,7 +951,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     json to_json_oaicompat_chat() {
-        bool first = n_decoded == 0;
+        bool first = n_decoded == 1;
         std::time_t t = std::time(0);
         json choices;
 
@@ -962,15 +962,18 @@ struct server_task_result_cmpl_partial : server_task_result {
                                             {"delta", json{{"role", "assistant"}}}}});
             } else {
                 // We have to send this as two updates to conform to openai behavior
+                // initial_ret is the role message for stream=True
                 json initial_ret = json{{"choices", json::array({json{
                                         {"finish_reason", nullptr},
                                         {"index", 0},
                                         {"delta", json{
-                                            {"role", "assistant"}
+                                            {"role", "assistant"},
+                                            {"content", ""}
                                         }}}})},
                             {"created", t},
                             {"id", oaicompat_cmpl_id},
                             {"model", oaicompat_model},
+                            {"system_fingerprint", build_info},
                             {"object", "chat.completion.chunk"}};
 
                 json second_ret = json{
@@ -982,8 +985,19 @@ struct server_task_result_cmpl_partial : server_task_result {
                             {"created", t},
                             {"id", oaicompat_cmpl_id},
                             {"model", oaicompat_model},
+                            {"system_fingerprint", build_info},
                             {"object", "chat.completion.chunk"}};
 
+                if (prob_output.probs.size() > 0) {
+                    second_ret["choices"][0]["logprobs"] = json{
+                        {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+                    };
+                }
+
+                if (timings.prompt_n >= 0) {
+                    second_ret.push_back({"timings", timings.to_json()});
+                }
+
                 return std::vector<json>({initial_ret, second_ret});
             }
         } else {
index 491cb3a5df636632ae671efec44912b7249af751..bab5d005d96c29b28325a414d7102e580353ff97 100644 (file)
@@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
     })
     content = ""
     last_cmpl_id = None
-    for data in res:
+    for i, data in enumerate(res):
         choice = data["choices"][0]
+        if i == 0:
+            # Check first role message for stream=True
+            assert choice["delta"]["content"] == ""
+            assert choice["delta"]["role"] == "assistant"
+        else:
+            assert "role" not in choice["delta"]
         assert data["system_fingerprint"].startswith("b")
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
         if last_cmpl_id is None:
@@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token():
         "stream": True,
         "timings_per_token": True,
     })
-    for data in res:
-        assert "timings" in data
-        assert "prompt_per_second" in data["timings"]
-        assert "predicted_per_second" in data["timings"]
-        assert "predicted_n" in data["timings"]
-        assert data["timings"]["predicted_n"] <= 10
+    for i, data in enumerate(res):
+        if i == 0:
+            # Check first role message for stream=True
+            assert data["choices"][0]["delta"]["content"] == ""
+            assert data["choices"][0]["delta"]["role"] == "assistant"
+        else:
+            assert "role" not in data["choices"][0]["delta"]
+            assert "timings" in data
+            assert "prompt_per_second" in data["timings"]
+            assert "predicted_per_second" in data["timings"]
+            assert "predicted_n" in data["timings"]
+            assert data["timings"]["predicted_n"] <= 10
 
 
 def test_logprobs():
@@ -295,17 +307,23 @@ def test_logprobs_stream():
     )
     output_text = ''
     aggregated_text = ''
-    for data in res:
+    for i, data in enumerate(res):
         choice = data.choices[0]
-        if choice.finish_reason is None:
-            if choice.delta.content:
-                output_text += choice.delta.content
-            assert choice.logprobs is not None
-            assert choice.logprobs.content is not None
-            for token in choice.logprobs.content:
-                aggregated_text += token.token
-                assert token.logprob <= 0.0
-                assert token.bytes is not None
-                assert token.top_logprobs is not None
-                assert len(token.top_logprobs) > 0
+        if i == 0:
+            # Check first role message for stream=True
+            assert choice.delta.content == ""
+            assert choice.delta.role == "assistant"
+        else:
+            assert choice.delta.role is None
+            if choice.finish_reason is None:
+                if choice.delta.content:
+                    output_text += choice.delta.content
+                assert choice.logprobs is not None
+                assert choice.logprobs.content is not None
+                for token in choice.logprobs.content:
+                    aggregated_text += token.token
+                    assert token.logprob <= 0.0
+                    assert token.bytes is not None
+                    assert token.top_logprobs is not None
+                    assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text