server : maintain chat completion id for streaming responses (#5988)
author Minsoo Cheong <redacted>
Mon, 11 Mar 2024 08:09:32 +0000 (17:09 +0900)
committer GitHub <redacted>
Mon, 11 Mar 2024 08:09:32 +0000 (10:09 +0200)
* server: maintain chat completion id for streaming responses

* Update examples/server/utils.hpp

* Update examples/server/utils.hpp

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp
examples/server/utils.hpp
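
For context, the change generates a single chat completion id per request and threads it through both the final and the partial (streamed) OpenAI-compatible formatters, so every chunk of one streaming response reports the same "id". Below is a minimal, self-contained sketch of that pattern; make_completion_id() and format_chunk() are hypothetical stand-ins for the server's gen_chatcmplid() and format_partial_response_oaicompat(), not the actual implementations.

```cpp
#include <random>
#include <string>
#include <vector>

// Hypothetical stand-in for gen_chatcmplid(): build a "chatcmpl-..." id once per request.
static std::string make_completion_id() {
    static const char alphanum[] = "0123456789abcdefghijklmnopqrstuvwxyz";
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<> dist(0, sizeof(alphanum) - 2);

    std::string id = "chatcmpl-";
    for (int i = 0; i < 24; ++i) {
        id += alphanum[dist(gen)];
    }
    return id;
}

// Hypothetical stand-in for format_partial_response_oaicompat(): each chunk
// receives the already-generated id instead of generating a fresh one.
struct Chunk { std::string id; std::string content; };

static Chunk format_chunk(const std::string & completion_id, const std::string & content) {
    return { completion_id, content };
}

int main() {
    // Generated once, before deciding between the streaming and non-streaming paths.
    const std::string completion_id = make_completion_id();

    std::vector<Chunk> chunks;
    chunks.push_back(format_chunk(completion_id, "Hello"));
    chunks.push_back(format_chunk(completion_id, ", world"));
    // Both chunks now carry the same completion_id, matching OpenAI's streaming behavior.
    return 0;
}
```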

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c7d3ed01b63470e86267e3a465b1b3d528d6460e..3951507aa5fbea102230355cae5b69d1d8b456d4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3195,11 +3195,12 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.add_waiting_task_id(id_task);
         ctx_server.request_completion(id_task, -1, data, false, false);
 
+        const auto completion_id = gen_chatcmplid();
         if (!json_value(data, "stream", false)) {
             server_task_result result = ctx_server.queue_results.recv(id_task);
 
             if (!result.error && result.stop) {
-                json result_oai = format_final_response_oaicompat(data, result.data);
+                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
 
                 res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
             } else {
@@ -3208,11 +3209,11 @@ int main(int argc, char ** argv) {
             }
             ctx_server.queue_results.remove_waiting_task_id(id_task);
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                 while (true) {
                     server_task_result result = ctx_server.queue_results.recv(id_task);
                     if (!result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(result.data);
+                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
 
                         for (auto it = result_array.begin(); it != result_array.end(); ++it) {
                             if (!it->empty()) {
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index df0a27782e6460c2d56395d1130eaf810dc9600e..f27af81e99ef50e67a41013503c7cffe8cf393e3 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -378,7 +378,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
+static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
     bool stopped_word        = result.count("stopped_word") != 0;
     bool stopped_eos         = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -412,7 +412,7 @@ static json format_final_response_oaicompat(const json & request, json result, b
             {"prompt_tokens",     num_prompt_tokens},
             {"total_tokens",      num_tokens_predicted + num_prompt_tokens}
         }},
-        {"id", gen_chatcmplid()}
+        {"id", completion_id}
     };
 
     if (server_verbose) {
@@ -427,7 +427,7 @@ static json format_final_response_oaicompat(const json & request, json result, b
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(json result) {
+static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -471,7 +471,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
                                             {"role", "assistant"}
                                         }}}})},
                             {"created", t},
-                            {"id", gen_chatcmplid()},
+                            {"id", completion_id},
                             {"model", modelname},
                             {"object", "chat.completion.chunk"}};
 
@@ -482,7 +482,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
                                                             {"content", content}}}
                                                             }})},
                             {"created", t},
-                            {"id", gen_chatcmplid()},
+                            {"id", completion_id},
                             {"model", modelname},
                             {"object", "chat.completion.chunk"}};
 
@@ -509,7 +509,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
     json ret = json {
         {"choices", choices},
         {"created", t},
-        {"id",      gen_chatcmplid()},
+        {"id",      completion_id},
         {"model",   modelname},
         {"object",  "chat.completion.chunk"}
     };
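
A note on the server.cpp hunk above: completion_id is captured by value in the chunked_content_provider lambda. The provider is invoked by the HTTP layer after the request handler has returned, so a by-reference capture of the local would dangle, while the by-value copy also guarantees every chunk of the response uses the same id. A minimal sketch of that lifetime pattern, assuming a generic provider callback rather than httplib's DataSink API:

```cpp
#include <functional>
#include <iostream>
#include <string>

// Sketch: build a streaming provider that outlives the function creating it.
// Capturing completion_id by value keeps it alive for every later invocation;
// capturing the local by reference would leave a dangling reference.
static std::function<void(const std::string &)> make_provider() {
    const std::string completion_id = "chatcmpl-example";   // generated once per request

    return [completion_id](const std::string & token) {
        // Every chunk emitted by this provider reports the same id.
        std::cout << "{\"id\":\"" << completion_id << "\",\"delta\":\"" << token << "\"}\n";
    };
}

int main() {
    auto provider = make_provider();   // the creating scope has ended before chunks stream
    provider("Hello");
    provider(", world");
    return 0;
}
```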