server : embeddings compatibility for OpenAI (#5190)
author    Wu Jian Ping <redacted>
          Mon, 29 Jan 2024 13:48:10 +0000 (21:48 +0800)
committer GitHub <redacted>
          Mon, 29 Jan 2024 13:48:10 +0000 (15:48 +0200)
examples/server/oai.hpp
examples/server/server.cpp

index bc5db6eef2b551b4f6bdcad5bc1d8dcc6b849f1e..43410f803d469326f5098dd6be4fc45cde155150 100644 (file)
@@ -206,3 +206,18 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
 
     return std::vector<json>({ret});
 }
+
+inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
+{
+    json res =
+        json{
+            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+            {"object", "list"},
+            {"usage",
+                json{{"prompt_tokens", 0},
+                     {"total_tokens", 0}}},
+            {"data", embeddings}
+        };
+    return res;
+}
+
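
For reference, here is a sketch of the OpenAI-style response the new format_embeddings_response_oaicompat helper assembles, written with the same nlohmann::json types the server uses. The field values are illustrative only; note that the usage counts are hard-coded to zero in this commit:

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    // Illustrative shape only: what format_embeddings_response_oaicompat returns.
    // "model" echoes the request's "model" field or falls back to
    // DEFAULT_OAICOMPAT_MODEL; prompt_tokens/total_tokens are always zero here.
    json example = {
        {"model",  "llama-2"},                                  // hypothetical model name
        {"object", "list"},
        {"usage",  {{"prompt_tokens", 0}, {"total_tokens", 0}}},
        {"data",   json::array({
            {{"embedding", json::array({0.12, -0.34, 0.56})},   // hypothetical values
             {"index", 0},
             {"object", "embedding"}}
        })}
    };
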
index a48582ad9178223f1253dce2b0180834b06a0efe..11dd82c33c106992079ca1fdd8da59487f239c86 100644 (file)
@@ -2929,6 +2929,66 @@ int main(int argc, char **argv)
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
             });
 
+    svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                const json body = json::parse(req.body);
+
+                json prompt;
+                if (body.count("input") != 0)
+                {
+                    prompt = body["input"];
+                    // the input is an array: embed each element as its own task
+                    if (prompt.is_array()) {
+                        json data = json::array();
+                        int i = 0;
+                        for (const json &elem : prompt) {
+                            const int task_id = llama.queue_tasks.get_new_id();
+                            llama.queue_results.add_waiting_task_id(task_id);
+                            llama.request_completion(task_id, {{"prompt", elem}, {"n_predict", 0}}, false, true, -1);
+
+                            // get the result
+                            task_result result = llama.queue_results.recv(task_id);
+                            llama.queue_results.remove_waiting_task_id(task_id);
+
+                            json embedding = json{
+                                {"embedding", json_value(result.result_json, "embedding", json::array())},
+                                {"index", i++},
+                                {"object", "embedding"}
+                            };
+                            data.push_back(embedding);
+                        }
+                        json result = format_embeddings_response_oaicompat(body, data);
+                        return res.set_content(result.dump(), "application/json; charset=utf-8");
+                    }
+                }
+                else
+                {
+                    prompt = "";
+                }
+
+                // create and queue the task
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
+
+                // get the result
+                task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+
+                json data = json::array({json{
+                        {"embedding", json_value(result.result_json, "embedding", json::array())},
+                        {"index", 0},
+                        {"object", "embedding"}
+                    }}
+                );
+
+                json root = format_embeddings_response_oaicompat(body, data);
+
+                // send the result
+                return res.set_content(root.dump(), "application/json; charset=utf-8");
+            });
+
     // GG: if I put the main loop inside a thread, it crashes on the first request when built in Debug!?
     //     "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()
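
To exercise the new endpoint end to end, here is a minimal client sketch reusing the same cpp-httplib and nlohmann::json libraries the server is built on. The host, port (8080 is the example server's default), and model name are assumptions, not part of this commit:

    #include <httplib.h>
    #include <nlohmann/json.hpp>
    #include <iostream>

    using json = nlohmann::json;

    int main() {
        httplib::Client cli("http://localhost:8080"); // assumed host/port

        // "input" may be a single string or an array of strings; the array
        // form takes the batched path in the handler above.
        json body = {
            {"input", json::array({"hello world", "second prompt"})},
            {"model", "local-model"} // hypothetical; the response echoes this
                                     // or DEFAULT_OAICOMPAT_MODEL if omitted
        };

        auto res = cli.Post("/v1/embeddings", body.dump(), "application/json");
        if (res && res->status == 200) {
            const json out = json::parse(res->body);
            for (const auto & d : out["data"]) {
                std::cout << "index " << d["index"]
                          << ": " << d["embedding"].size() << " dims\n";
            }
        }
        return 0;
    }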