]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Server: Use multi-task for embeddings endpoint (#6001)
authorXuan Son Nguyen <redacted>
Wed, 13 Mar 2024 10:39:11 +0000 (11:39 +0100)
committerGitHub <redacted>
Wed, 13 Mar 2024 10:39:11 +0000 (11:39 +0100)
* use multitask for embd endpoint

* specify types

* remove redundant {"n_predict", 0}

examples/server/server.cpp
examples/server/utils.hpp

index b63a6f243b6966581f8b78baed62812733f1490f..3172d96dd03539e38019d5b7898671ede4e6f28b 100644 (file)
@@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods",     "POST");
         res.set_header("Access-Control-Allow-Headers",     "*");
+        return res.set_content("", "application/json; charset=utf-8");
     });
 
     svr->set_logger(log_server_request);
@@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool is_openai = false;
 
-        // an input prompt can string or a list of tokens (integer)
-        std::vector<json> prompts;
+        // an input prompt can be a string or a list of tokens (integer)
+        json prompt;
         if (body.count("input") != 0) {
             is_openai = true;
-            if (body["input"].is_array()) {
-                // support multiple prompts
-                for (const json & elem : body["input"]) {
-                    prompts.push_back(elem);
-                }
-            } else {
-                // single input prompt
-                prompts.push_back(body["input"]);
-            }
+            prompt = body["input"];
         } else if (body.count("content") != 0) {
-            // only support single prompt here
-            std::string content = body["content"];
-            prompts.push_back(content);
+            // with "content", we only support single prompt
+            prompt = std::vector<std::string>{body["content"]};
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        // process all prompts
-        json responses = json::array();
-        for (auto & prompt : prompts) {
-            // TODO @ngxson : maybe support multitask for this endpoint?
-            // create and queue the task
+        // create and queue the task
+        json responses;
+        {
             const int id_task = ctx_server.queue_tasks.get_new_id();
-
             ctx_server.queue_results.add_waiting_task_id(id_task);
-            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
 
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
             ctx_server.queue_results.remove_waiting_task_id(id_task);
             if (!result.error) {
-                // append to the responses
-                responses.push_back(result.data);
+                if (result.data.count("results")) {
+                    // result for multi-task
+                    responses = result.data["results"];
+                } else {
+                    // result for single task
+                    responses = std::vector<json>{result.data};
+                }
             } else {
                 // error received, ignore everything else
                 res_error(res, result.data);
@@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root;
-        if (is_openai) {
-            json res_oai = json::array();
-            int i = 0;
-            for (auto & elem : responses) {
-                res_oai.push_back(json{
-                    {"embedding", json_value(elem, "embedding", json::array())},
-                    {"index",     i++},
-                    {"object",    "embedding"}
-                });
-            }
-            root = format_embeddings_response_oaicompat(body, res_oai);
-        } else {
-            root = responses[0];
-        }
+        json root = is_openai
+            ? format_embeddings_response_oaicompat(body, responses)
+            : responses[0];
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
 
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
     //
     // Router
     //
@@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) {
     }
 
     // using embedded static files
-    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-            return false;
-        };
-    };
-
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
-        return res.set_content("", "application/json; charset=utf-8");
-    });
     svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
     svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
     svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
index 48aeef4ebea677011f077bb4bf9468c3df9f0542..2ddb2cd21f8d6c4ce641125e6e7799294fecf7f9 100644 (file)
@@ -529,6 +529,16 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
 }
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+    json data = json::array();
+    int i = 0;
+    for (auto & elem : embeddings) {
+        data.push_back(json{
+            {"embedding", json_value(elem, "embedding", json::array())},
+            {"index",     i++},
+            {"object",    "embedding"}
+        });
+    }
+
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
@@ -536,7 +546,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
-        {"data", embeddings}
+        {"data", data}
     };
 
     return res;