]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add "/chat/completions" alias for "/v1/...` (#5722)
authorJorge A <redacted>
Wed, 28 Feb 2024 08:39:15 +0000 (01:39 -0700)
committerGitHub <redacted>
Wed, 28 Feb 2024 08:39:15 +0000 (10:39 +0200)
* Add "/chat/completions" as alias for "/v1/chat/completions"

* merge to upstream master

* minor : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp
examples/server/tests/features/parallel.feature
examples/server/tests/features/steps/steps.py

index 846ef7e5fee4f556567292118488fe343082a19c..6b3ee531cfb57841f543be7f33253a9409f066fa 100644 (file)
@@ -3211,87 +3211,88 @@ int main(int argc, char **argv)
                 res.set_content(models.dump(), "application/json; charset=utf-8");
             });
 
+    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+    {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        if (!validate_api_key(req, res)) {
+            return;
+        }
+        json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
 
-    // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
-            {
-                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-                if (!validate_api_key(req, res)) {
-                    return;
-                }
-                json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
-
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, data, false, false, -1);
+        const int task_id = llama.queue_tasks.get_new_id();
+        llama.queue_results.add_waiting_task_id(task_id);
+        llama.request_completion(task_id, data, false, false, -1);
 
-                if (!json_value(data, "stream", false)) {
-                    std::string completion_text;
-                    task_result result = llama.queue_results.recv(task_id);
+        if (!json_value(data, "stream", false)) {
+            std::string completion_text;
+            task_result result = llama.queue_results.recv(task_id);
 
-                    if (!result.error && result.stop) {
-                        json oaicompat_result = format_final_response_oaicompat(data, result);
+            if (!result.error && result.stop) {
+                json oaicompat_result = format_final_response_oaicompat(data, result);
 
-                        res.set_content(oaicompat_result.dump(-1, ' ', false,
-                                            json::error_handler_t::replace),
-                                            "application/json; charset=utf-8");
-                    } else {
-                        res.status = 500;
-                        res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
-                    }
-                    llama.queue_results.remove_waiting_task_id(task_id);
-                } else {
-                    const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
-                        while (true) {
-                            task_result llama_result = llama.queue_results.recv(task_id);
-                            if (!llama_result.error) {
-                                std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
+                res.set_content(oaicompat_result.dump(-1, ' ', false,
+                                    json::error_handler_t::replace),
+                                    "application/json; charset=utf-8");
+            } else {
+                res.status = 500;
+                res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
+            }
+            llama.queue_results.remove_waiting_task_id(task_id);
+        } else {
+            const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
+                while (true) {
+                    task_result llama_result = llama.queue_results.recv(task_id);
+                    if (!llama_result.error) {
+                        std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
 
-                                for (auto it = result_array.begin(); it != result_array.end(); ++it)
-                                {
-                                    if (!it->empty()) {
-                                        const std::string str =
-                                            "data: " +
-                                            it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                            "\n\n";
-                                        LOG_VERBOSE("data stream", {{"to_send", str}});
-                                        if (!sink.write(str.c_str(), str.size())) {
-                                            llama.queue_results.remove_waiting_task_id(task_id);
-                                            return false;
-                                        }
-                                    }
-                                }
-                                if (llama_result.stop) {
-                                    break;
-                                }
-                            } else {
+                        for (auto it = result_array.begin(); it != result_array.end(); ++it)
+                        {
+                            if (!it->empty()) {
                                 const std::string str =
-                                    "error: " +
-                                    llama_result.result_json.dump(-1, ' ', false,
-                                            json::error_handler_t::replace) +
+                                    "data: " +
+                                    it->dump(-1, ' ', false, json::error_handler_t::replace) +
                                     "\n\n";
                                 LOG_VERBOSE("data stream", {{"to_send", str}});
                                 if (!sink.write(str.c_str(), str.size())) {
                                     llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
-                                break;
                             }
                         }
-                        sink.done();
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                        return true;
-                    };
+                        if (llama_result.stop) {
+                            break;
+                        }
+                    } else {
+                        const std::string str =
+                            "error: " +
+                            llama_result.result_json.dump(-1, ' ', false,
+                                    json::error_handler_t::replace) +
+                            "\n\n";
+                        LOG_VERBOSE("data stream", {{"to_send", str}});
+                        if (!sink.write(str.c_str(), str.size())) {
+                            llama.queue_results.remove_waiting_task_id(task_id);
+                            return false;
+                        }
+                        break;
+                    }
+                }
+                sink.done();
+                llama.queue_results.remove_waiting_task_id(task_id);
+                return true;
+            };
 
-                    auto on_complete = [task_id, &llama](bool) {
-                        // cancel request
-                        llama.request_cancel(task_id);
-                        llama.queue_results.remove_waiting_task_id(task_id);
-                    };
+            auto on_complete = [task_id, &llama](bool) {
+                // cancel request
+                llama.request_cancel(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+            };
 
-                    res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-                }
-            });
+            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+        }
+    };
+
+    svr.Post("/chat/completions", chat_completions);
+    svr.Post("/v1/chat/completions", chat_completions);
 
     svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
             {
index c85f9de1d9a5259c9fe9cfe5623c83fe06bc6ec0..5f895cf90b9668995edee6da78b51c24ade99927 100644 (file)
@@ -54,6 +54,28 @@ Feature: Parallel
       | disabled  | 128       |
       | enabled   | 64        |
 
+  Scenario Outline: Multi users OAI completions compatibility no v1
+    Given a system prompt You are a writer.
+    And   a model tinyllama-2
+    Given a prompt:
+      """
+      Write a very long book.
+      """
+    And a prompt:
+      """
+      Write another a poem.
+      """
+    And <n_predict> max tokens to predict
+    And streaming is <streaming>
+    Given concurrent OAI completions requests no v1
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted with <n_predict> tokens
+    Examples:
+      | streaming | n_predict |
+      | disabled  | 128       |
+      | enabled   | 64        |
+
   Scenario:  Multi users with total number of tokens to predict exceeds the KV Cache size #3969
     Given a prompt:
       """
index ad87fcb820aa8f856d8ad9467b3599de32745970..381da105e279e5efd1712b4c321fe5eb49ee72df 100644 (file)
@@ -231,6 +231,7 @@ async def step_oai_chat_completions(context, api_error):
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
                                             context.base_url,
+                                            '/v1/chat',
                                             False,
                                             model=context.model if hasattr(context, 'model') else None,
 
@@ -288,6 +289,28 @@ async def step_oai_chat_completions(context):
                               # user_prompt is inserted automatically
                               context.system_prompt,
                               context.base_url,
+                              '/v1/chat/completions',
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)
+
+
+@step(u'concurrent OAI completions requests no v1')
+@async_run_until_complete
+async def step_oai_chat_completions(context):
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              '/chat/completions',
                               True,  # async_client
                               model=context.model
                               if hasattr(context, 'model') else None,
@@ -497,6 +520,7 @@ async def request_completion(prompt,
 async def oai_chat_completions(user_prompt,
                                system_prompt,
                                base_url,
+                               base_path,
                                async_client,
                                debug=False,
                                model=None,
@@ -537,7 +561,7 @@ async def oai_chat_completions(user_prompt,
         origin = 'llama.cpp'
         headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
-            async with session.post(f'{base_url}/v1/chat/completions',
+            async with session.post(f'{base_url}{base_path}',
                                     json=payload,
                                     headers=headers) as response:
                 if enable_streaming:
@@ -579,7 +603,7 @@ async def oai_chat_completions(user_prompt,
     else:
         try:
             openai.api_key = user_api_key
-            openai.api_base = f'{base_url}/v1/chat'
+            openai.api_base = f'{base_url}{base_path}'
             chat_completion = openai.Completion.create(
                 messages=payload['messages'],
                 model=model,