]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Server: clean up OAI params parsing function (#6284)
authorXuan Son Nguyen <redacted>
Mon, 25 Mar 2024 08:42:17 +0000 (09:42 +0100)
committerGitHub <redacted>
Mon, 25 Mar 2024 08:42:17 +0000 (09:42 +0100)
* server: clean up oai parsing function

* fix response_format

* fix empty response_format

* minor fixes

* add TODO for logprobs

* update docs

examples/server/README.md
examples/server/server.cpp
examples/server/utils.hpp

index dfea2b9050210e6543e00b5ebdd97006f5f27396..49121a460f8c3d684a6068db28759ecdc079e147 100644 (file)
@@ -360,7 +360,7 @@ Notice that each `probs` is an array of length `n_probs`.
 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
 
-- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only ChatML-tuned models, such as Dolphin, OpenOrca, OpenHermes, OpenChat-3.5, etc can be used with this endpoint.
+- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only model with [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, ChatML template will be used.
 
     *Options:*
 
index b02c2546eb4c6ab3f4767791e6c747008d7cb2aa..338e60f28d62555dd53fcdb7f416073c0a9e2b45 100644 (file)
@@ -847,9 +847,16 @@ struct server_context {
         slot.sparams.penalize_nl       = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
         slot.params.n_keep             = json_value(data, "n_keep",            slot.params.n_keep);
         slot.params.seed               = json_value(data, "seed",              default_params.seed);
-        if (data.contains("json_schema") && !data.contains("grammar")) {
+        slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
+        slot.sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
+
+        // process "json_schema" and "grammar"
+        if (data.contains("json_schema") && data.contains("grammar")) {
+            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
+            return false;
+        } else if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema                = json_value(data, "json_schema",       json::object());
+                auto schema                = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar       = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -858,8 +865,6 @@ struct server_context {
         } else {
             slot.sparams.grammar       = json_value(data, "grammar",           default_sparams.grammar);
         }
-        slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
-        slot.sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
 
         if (slot.params.cache_prompt && slot.ga_n != 1) {
             LOG_WARNING("cache_prompt is not supported with group-attention", {});
index 8f20ff61454e9a2892d17776ba2576531e00a335..7d9ab622bb6e5d2641569b183bd27068fe6b43b4 100644 (file)
@@ -352,51 +352,71 @@ static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]             = json_value(body,   "model",             std::string("unknown"));
-    llama_params["prompt"]            = format_chat(model, chat_template,       body["messages"]);
-    llama_params["cache_prompt"]      = json_value(body,   "cache_prompt",      false);
-    llama_params["temperature"]       = json_value(body,   "temperature",       0.0);
-    llama_params["top_k"]             = json_value(body,   "top_k",             default_sparams.top_k);
-    llama_params["top_p"]             = json_value(body,   "top_p",             1.0);
-    llama_params["n_predict"]         = json_value(body,   "max_tokens",        -1);
-    llama_params["logit_bias"]        = json_value(body,   "logit_bias",        json::object());
     llama_params["frequency_penalty"] = json_value(body,   "frequency_penalty", 0.0);
+    llama_params["logit_bias"]        = json_value(body,   "logit_bias",        json::object());
+    llama_params["n_predict"]         = json_value(body,   "max_tokens",        -1);
     llama_params["presence_penalty"]  = json_value(body,   "presence_penalty",  0.0);
     llama_params["seed"]              = json_value(body,   "seed",              LLAMA_DEFAULT_SEED);
     llama_params["stream"]            = json_value(body,   "stream",            false);
-    llama_params["mirostat"]          = json_value(body,   "mirostat",          default_sparams.mirostat);
-    llama_params["mirostat_tau"]      = json_value(body,   "mirostat_tau",      default_sparams.mirostat_tau);
-    llama_params["mirostat_eta"]      = json_value(body,   "mirostat_eta",      default_sparams.mirostat_eta);
-    llama_params["penalize_nl"]       = json_value(body,   "penalize_nl",       default_sparams.penalize_nl);
-    llama_params["typical_p"]         = json_value(body,   "typical_p",         default_sparams.typical_p);
-    llama_params["repeat_last_n"]     = json_value(body,   "repeat_last_n",     default_sparams.penalty_last_n);
-    llama_params["ignore_eos"]        = json_value(body,   "ignore_eos",        false);
-    llama_params["tfs_z"]             = json_value(body,   "tfs_z",             default_sparams.tfs_z);
-    llama_params["n_keep"]            = json_value(body,   "n_keep",            0);
-
-    if (body.contains("grammar")) {
-        llama_params["grammar"] = json_value(body, "grammar", json::object());
-    }
+    llama_params["temperature"]       = json_value(body,   "temperature",       0.0);
+    llama_params["top_p"]             = json_value(body,   "top_p",             1.0);
 
-    if (body.contains("response_format")) {
-        auto response_format = json_value(body, "response_format", json::object());
-        if (response_format.contains("type")) {
-            if (response_format["type"] == "json_object") {
-                llama_params["json_schema"] = json_value(response_format, "schema", json::object());
-            } else {
-                throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
-            }
-        }
-    }
+    // Apply chat template to the list of messages
+    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
 
-    // Handle 'stop' field
+    // Handle "stop" field
     if (body.contains("stop") && body["stop"].is_string()) {
         llama_params["stop"] = json::array({body["stop"].get<std::string>()});
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
+    // Some chat templates don't use EOS token to stop generation
+    // We must add their end sequences to list of stop words
+    llama_params["stop"].push_back("<|im_end|>"); // chatml
+    llama_params["stop"].push_back("<end_of_turn>"); // gemma
 
-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    // Handle "response_format" field
+    if (body.contains("response_format")) {
+        json response_format      = json_value(body, "response_format", json::object());
+        std::string response_type = json_value(response_format, "type", std::string());
+        if (response_type == "json_object") {
+            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+        } else if (!response_type.empty() && response_type != "text") {
+            throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
+        }
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Handle "logprobs" field
+    // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
+    if (body.contains("logprobs")) {
+        llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
+    } else if (body.contains("top_logprobs")) {
+        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+    for (auto & param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
+    }
+
+    // Copy remaining properties to llama_params
+    // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
+    // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
+    for (const auto & item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }
 
     return llama_params;
 }