git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: allow filtering llama server response fields (#10940)
author NeverLucky <redacted>
Tue, 24 Dec 2024 16:39:49 +0000 (19:39 +0300)
committer GitHub <redacted>
Tue, 24 Dec 2024 16:39:49 +0000 (17:39 +0100)
* llama_server_response_fields

* llama_server_response_fields_fix_issues

* params fixes

* fix

* clarify docs

* change to "response_fields"

---------

Co-authored-by: Xuan Son Nguyen <redacted>
examples/server/README.md
examples/server/server.cpp
examples/server/tests/unit/test_completion.py
examples/server/utils.hpp

diff --git a/examples/server/README.md b/examples/server/README.md
index 5e3d6a6e643a6eb995255c757a8e562d9028faad..c7d91be9976c4924abdac273971bbc0043497442 100644
@@ -450,6 +450,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
 
+`response_fields`: A list of response fields to include in the response, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If a specified field is missing, it is simply omitted from the response without triggering an error.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
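For illustration, here is a minimal client-side sketch of the new `response_fields` parameter against a running llama-server instance; the host/port, the `requests` dependency, and the sample output are assumptions for the example, not part of this commit:

```python
# Sketch only: assumes llama-server is listening on localhost:8080
# and the `requests` package is installed.
import requests

resp = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "I believe the meaning of life is",
        "n_predict": 8,
        # keep only these fields; fields that do not resolve are silently dropped
        "response_fields": ["content", "generation_settings/n_predict"],
    },
)
print(resp.json())
# e.g. {"content": " to be happy.", "generation_settings/n_predict": 8}
```

Note that a nested selection such as `generation_settings/n_predict` comes back under its full slash-separated path, which is what the new unit test below asserts.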
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 476a9225f75b781360dca73dc71ad1851fb53fab..3fbfb13c49b729f7a673d4c2d6ae4e1aef3133f1 100644
@@ -92,6 +92,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
+    std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
     bool ignore_eos = false;
@@ -209,6 +210,7 @@ struct server_task {
         params.n_discard        = json_value(data, "n_discard",          defaults.n_discard);
       //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms",   defaults.t_max_predict_ms);
+        params.response_fields  = json_value(data, "response_fields",   std::vector<std::string>());
 
         params.sampling.top_k              = json_value(data, "top_k",              defaults.sampling.top_k);
         params.sampling.top_p              = json_value(data, "top_p",              defaults.sampling.top_p);
@@ -522,6 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<std::string>  response_fields;
 
     slot_params generation_params;
 
@@ -568,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
-        return res;
+        return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -2066,6 +2069,7 @@ struct server_context {
         res->tokens          = slot.generated_tokens;
         res->timings         = slot.get_timings();
         res->prompt          = common_detokenize(ctx, slot.prompt_tokens, true);
+        res->response_fields = slot.params.response_fields;
 
         res->truncated           = slot.truncated;
         res->n_decoded           = slot.n_decoded;
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index b88d45f18547ff8240311f165b79f5713af0b662..00d5ce391d8f077ecf325fac3a60518d0c364837 100644
@@ -257,6 +257,40 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         # assert match_regex(re_content, res.body["content"])
 
 
+@pytest.mark.parametrize(
+    "prompt,n_predict,response_fields",
+    [
+        ("I believe the meaning of life is", 8, []),
+        ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
+    ],
+)
+def test_completion_response_fields(
+    prompt: str, n_predict: int, response_fields: list[str]
+):
+    global server
+    server.start()
+    res = server.make_request(
+        "POST",
+        "/completion",
+        data={
+            "n_predict": n_predict,
+            "prompt": prompt,
+            "response_fields": response_fields,
+        },
+    )
+    assert res.status_code == 200
+    assert "content" in res.body
+    assert len(res.body["content"])
+    if len(response_fields):
+        assert res.body["generation_settings/n_predict"] == n_predict
+        assert res.body["prompt"] == "<s> " + prompt
+        assert isinstance(res.body["content"], str)
+        assert len(res.body) == len(response_fields)
+    else:
+        assert len(res.body)
+        assert "generation_settings" in res.body
+
+
 def test_n_probs():
     global server
     server.start()
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 1987acac89159e461872489ab52e1b5df61f0618..043d8b52897db01b6d5c596214dcbc1e0a28bf3e 100644
@@ -90,6 +90,28 @@ static bool json_is_array_of_mixed_numbers_strings(const json & data) {
     return false;
 }
 
+// get values from a JSON object by slash-separated key paths (e.g. "key1/key2")
+static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
+    json result = json::object();
+
+    for (const std::string & path : paths) {
+        json current = js;
+        const auto keys = string_split<std::string>(path, /*separator*/ '/');
+        bool valid_path = true;
+        for (const std::string & k : keys) {
+            if (valid_path && current.is_object() && current.contains(k)) {
+                current = current[k];
+            } else {
+                valid_path = false;
+            }
+        }
+        if (valid_path) {
+            result[path] = current;
+        }
+    }
+    return result;
+}
+
 /**
  * this handles 2 cases:
  * - only string, example: "string"
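To make the filtering semantics of the new `json_get_nested_values` helper concrete, a rough Python equivalent of the same lookup logic (illustration only, not part of the commit) might look like this:

```python
# Mirrors json_get_nested_values: each path is a slash-separated key chain;
# paths that do not fully resolve are silently dropped from the result.
def get_nested_values(paths: list[str], obj: dict) -> dict:
    result = {}
    for path in paths:
        current = obj
        valid = True
        for key in path.split("/"):
            if valid and isinstance(current, dict) and key in current:
                current = current[key]
            else:
                valid = False
        if valid:
            result[path] = current
    return result

body = {"content": "42", "generation_settings": {"n_predict": 8}}
print(get_nested_values(["content", "generation_settings/n_predict", "missing/field"], body))
# -> {'content': '42', 'generation_settings/n_predict': 8}
```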