]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : bring back info of final chunk in stream mode (#10722)
authorXuan Son Nguyen <redacted>
Sun, 8 Dec 2024 19:38:51 +0000 (20:38 +0100)
committerGitHub <redacted>
Sun, 8 Dec 2024 19:38:51 +0000 (20:38 +0100)
* server : bring back into to final chunk in stream mode

* clarify a bit

* traling space

examples/server/server.cpp
examples/server/tests/unit/test_completion.py

index 1c21e55aaa011d7d8462ed151cfa29bd03924d26..1d9c0533d4c404c16427e2e093a0f244df7926bf 100644 (file)
@@ -392,7 +392,7 @@ struct server_task_result {
         return false;
     }
     virtual bool is_stop() {
-        // only used by server_task_result_cmpl_partial
+        // only used by server_task_result_cmpl_*
         return false;
     }
     virtual int get_index() {
@@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
         return index;
     }
 
+    virtual bool is_stop() override {
+        return true; // in stream mode, final responses are considered stop
+    }
+
     virtual json to_json() override {
-        return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
+        return oaicompat
+            ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
+            : to_json_non_oaicompat();
     }
 
     json to_json_non_oaicompat() {
         json res = json {
             {"index",               index},
-            {"content",             content},
+            {"content",             stream ? "" : content}, // in stream mode, content is already in last partial chunk
             {"id_slot",             id_slot},
             {"stop",                true},
             {"model",               oaicompat_model},
@@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result {
 
         return res;
     }
+
+    json to_json_oaicompat_chat_stream() {
+        std::time_t t = std::time(0);
+        std::string finish_reason = "length";
+        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+            finish_reason = "stop";
+        }
+
+        json choices = json::array({json{{"finish_reason", finish_reason},
+                                        {"index", 0},
+                                        {"delta", json::object()}}});
+
+        json ret = json {
+            {"choices", choices},
+            {"created", t},
+            {"id",      oaicompat_cmpl_id},
+            {"model",   oaicompat_model},
+            {"object",  "chat.completion.chunk"},
+            {"usage", json {
+                {"completion_tokens", n_decoded},
+                {"prompt_tokens",     n_prompt_tokens},
+                {"total_tokens",      n_decoded + n_prompt_tokens},
+            }},
+        };
+
+        if (timings.prompt_n >= 0) {
+            ret.push_back({"timings", timings.to_json()});
+        }
+
+        return ret;
+    }
 };
 
 struct server_task_result_cmpl_partial : server_task_result {
     int index = 0;
     std::string content;
 
-    bool truncated;
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
-    stop_type stop = STOP_TYPE_NONE;
-
     std::vector<completion_token_output> probs_output;
     result_timings timings;
 
@@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 
     virtual bool is_stop() override {
-        return stop != STOP_TYPE_NONE;
+        return false; // in stream mode, partial responses are not considered stop
     }
 
     virtual json to_json() override {
-        if (oaicompat) {
-            return to_json_oaicompat();
-        }
-        bool is_stop = stop != STOP_TYPE_NONE;
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
         // non-OAI-compat JSON
         json res = json {
             {"index",            index},
             {"content",          content},
-            {"stop_type",        stop_type_to_str(stop)},
-            {"stop",             is_stop},
+            {"stop",             false},
             {"id_slot",          id_slot},
             {"tokens_predicted", n_decoded},
             {"tokens_evaluated", n_prompt_tokens},
@@ -598,72 +631,54 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (!probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
         }
-        if (is_stop) {
-            res.push_back({"truncated", truncated});
-        }
         return res;
     }
 
     json to_json_oaicompat() {
         bool first = n_decoded == 0;
-
-        std::string finish_reason;
-        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-            finish_reason = "stop";
-        } else if (stop == STOP_TYPE_LIMIT) {
-            finish_reason = "length";
-        }
-
         std::time_t t = std::time(0);
-
         json choices;
 
-        if (!finish_reason.empty()) {
-            choices = json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}});
-        } else {
-            if (first) {
-                if (content.empty()) {
-                    choices = json::array({json{{"finish_reason", nullptr},
-                                                {"index", 0},
-                                                {"delta", json{{"role", "assistant"}}}}});
-                } else {
-                    // We have to send this as two updates to conform to openai behavior
-                    json initial_ret = json{{"choices", json::array({json{
-                                            {"finish_reason", nullptr},
+        if (first) {
+            if (content.empty()) {
+                choices = json::array({json{{"finish_reason", nullptr},
                                             {"index", 0},
-                                            {"delta", json{
-                                                {"role", "assistant"}
-                                            }}}})},
-                                {"created", t},
-                                {"id", oaicompat_cmpl_id},
-                                {"model", oaicompat_model},
-                                {"object", "chat.completion.chunk"}};
-
-                    json second_ret = json{
-                                {"choices", json::array({json{{"finish_reason", nullptr},
-                                                                {"index", 0},
-                                                                {"delta", json{
-                                                                {"content", content}}}
-                                                                }})},
-                                {"created", t},
-                                {"id", oaicompat_cmpl_id},
-                                {"model", oaicompat_model},
-                                {"object", "chat.completion.chunk"}};
-
-                    return std::vector<json>({initial_ret, second_ret});
-                }
+                                            {"delta", json{{"role", "assistant"}}}}});
             } else {
-                choices = json::array({json{
-                    {"finish_reason", nullptr},
-                    {"index", 0},
-                    {"delta",
-                    json{
-                        {"content", content},
-                    }},
-                }});
+                // We have to send this as two updates to conform to openai behavior
+                json initial_ret = json{{"choices", json::array({json{
+                                        {"finish_reason", nullptr},
+                                        {"index", 0},
+                                        {"delta", json{
+                                            {"role", "assistant"}
+                                        }}}})},
+                            {"created", t},
+                            {"id", oaicompat_cmpl_id},
+                            {"model", oaicompat_model},
+                            {"object", "chat.completion.chunk"}};
+
+                json second_ret = json{
+                            {"choices", json::array({json{{"finish_reason", nullptr},
+                                                            {"index", 0},
+                                                            {"delta", json{
+                                                            {"content", content}}}
+                                                            }})},
+                            {"created", t},
+                            {"id", oaicompat_cmpl_id},
+                            {"model", oaicompat_model},
+                            {"object", "chat.completion.chunk"}};
+
+                return std::vector<json>({initial_ret, second_ret});
             }
+        } else {
+            choices = json::array({json{
+                {"finish_reason", nullptr},
+                {"index", 0},
+                {"delta",
+                json{
+                    {"content", content},
+                }},
+            }});
         }
 
         json ret = json {
@@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
             ret.push_back({"timings", timings.to_json()});
         }
 
-        if (!finish_reason.empty()) {
-            ret.push_back({"usage", json {
-                {"completion_tokens", n_decoded},
-                {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens},
-            }});
-        }
-
         return std::vector<json>({ret});
     }
 };
@@ -1888,12 +1895,9 @@ struct server_context {
         res->index   = slot.index;
         res->content = tkn.text_to_send;
 
-        res->truncated       = slot.truncated;
         res->n_decoded       = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
 
-        res->stop = slot.stop;
-
         res->verbose           = slot.params.verbose;
         res->oaicompat         = slot.params.oaicompat;
         res->oaicompat_chat    = slot.params.oaicompat_chat;
@@ -1924,12 +1928,6 @@ struct server_context {
     }
 
     void send_final_response(server_slot & slot) {
-        if (slot.params.stream) {
-            // if in stream mode, send the last partial response
-            send_partial_response(slot, {0, "", {}});
-            return;
-        }
-
         auto res = std::make_unique<server_task_result_cmpl_final>();
         res->id              = slot.id_task;
         res->id_slot         = slot.id;
@@ -1948,6 +1946,7 @@ struct server_context {
         res->stop            = slot.stop;
 
         res->verbose           = slot.params.verbose;
+        res->stream            = slot.params.stream;
         res->oaicompat         = slot.params.oaicompat;
         res->oaicompat_chat    = slot.params.oaicompat_chat;
         res->oaicompat_model   = slot.params.oaicompat_model;
@@ -2100,7 +2099,10 @@ struct server_context {
                 return;
             }
 
-            GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
+            GGML_ASSERT(
+                dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+            );
             if (!result_handler(result)) {
                 cancel_tasks(id_tasks);
                 break;
index 1c3aa77de5bba5fe5cad863ba6b54da0f3975536..7f4f9cd038be4ba8d2785a2bd21672d54f73f865 100644 (file)
@@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
     })
     content = ""
     for data in res:
+        assert "stop" in data and type(data["stop"]) == bool
         if data["stop"]:
             assert data["timings"]["prompt_n"] == n_prompt
             assert data["timings"]["predicted_n"] == n_predicted
             assert data["truncated"] == truncated
+            assert data["stop_type"] == "limit"
+            assert "generation_settings" in data
+            assert server.n_predict is not None
+            assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
+            assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
             content += data["content"]