]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add "tokens" output (#10853)
authorGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 09:05:29 +0000 (11:05 +0200)
committerGitHub <redacted>
Wed, 18 Dec 2024 09:05:29 +0000 (11:05 +0200)
* server : add "tokens" output

ggml-ci

* server : update readme

ggml-ci

* server : return tokens ids only if requested

ggml-ci

* tests : improve "tokens" type check

Co-authored-by: Xuan Son Nguyen <redacted>
* server : remove "tokens" from the OAI endpoint

ggml-ci

---------

Co-authored-by: Xuan Son Nguyen <redacted>
examples/server/README.md
examples/server/server.cpp
examples/server/tests/unit/test_completion.py

index 63a7bf43a920d030ac662adc45af2b0e130a2477..ecd24c899fc86dfaf5932e99eee6e0733fd0afd3 100644 (file)
@@ -438,19 +438,22 @@ These words will not be included in the completion, so make sure to add them to
 
 `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true`
 
+`return_tokens`: Return the raw generated token ids in the `tokens` field. Otherwise `tokens` remains empty. Default: `false`
+
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response.  Default: `false`
 
 **Response format**
 
-- Note: In streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
+- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
 
 - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:
 
 ```json
 {
-  "content": "<the token selected by the model>",
+  "content": "<the token generated by the model>",
+  "tokens": [ generated token ids if requested ],
   "probs": [
     {
       "prob": float,
@@ -468,6 +471,7 @@ These words will not be included in the completion, so make sure to add them to
 Notice that each `probs` is an array of length `n_probs`.
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
+- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
 - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
 - `model`: The path to the model loaded with `-m`
index 71566b94e61bb050b9ffb0f486f6001bea682599..40aac33f0bf135f2b1871af5cf1a1e9bd812ee7f 100644 (file)
@@ -79,8 +79,9 @@ enum error_type {
 };
 
 struct slot_params {
-    bool stream       = true;
-    bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+    bool stream        = true;
+    bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens = false;
 
     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -199,6 +200,7 @@ struct server_task {
 
         params.stream           = json_value(data, "stream",             false);
         params.cache_prompt     = json_value(data, "cache_prompt",       true);
+        params.return_tokens    = json_value(data, "return_tokens",      false);
         params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent         = json_value(data, "n_indent",           defaults.n_indent);
         params.n_keep           = json_value(data, "n_keep",             defaults.n_keep);
@@ -468,7 +470,10 @@ struct completion_token_output {
 
 struct server_task_result_cmpl_final : server_task_result {
     int index = 0;
-    std::string content;
+
+    std::string  content;
+    llama_tokens tokens;
+
     bool stream;
     result_timings timings;
     std::string prompt;
@@ -510,6 +515,7 @@ struct server_task_result_cmpl_final : server_task_result {
         json res = json {
             {"index",               index},
             {"content",             stream ? "" : content}, // in stream mode, content is already in last partial chunk
+            {"tokens",              stream ? llama_tokens {} : tokens},
             {"id_slot",             id_slot},
             {"stop",                true},
             {"model",               oaicompat_model},
@@ -539,9 +545,9 @@ struct server_task_result_cmpl_final : server_task_result {
         json choices = json::array({json{
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", json{
+            {"message", json {
                 {"content", content},
-                {"role", "assistant"}
+                {"role",    "assistant"}
             }
         }}});
 
@@ -605,7 +611,9 @@ struct server_task_result_cmpl_final : server_task_result {
 
 struct server_task_result_cmpl_partial : server_task_result {
     int index = 0;
-    std::string content;
+
+    std::string  content;
+    llama_tokens tokens;
 
     int32_t n_decoded;
     int32_t n_prompt_tokens;
@@ -637,6 +645,7 @@ struct server_task_result_cmpl_partial : server_task_result {
         json res = json {
             {"index",            index},
             {"content",          content},
+            {"tokens",           tokens},
             {"stop",             false},
             {"id_slot",          id_slot},
             {"tokens_predicted", n_decoded},
@@ -678,7 +687,7 @@ struct server_task_result_cmpl_partial : server_task_result {
                 json second_ret = json{
                             {"choices", json::array({json{{"finish_reason", nullptr},
                                                             {"index", 0},
-                                                            {"delta", json{
+                                                            {"delta", json {
                                                             {"content", content}}}
                                                             }})},
                             {"created", t},
@@ -693,7 +702,7 @@ struct server_task_result_cmpl_partial : server_task_result {
                 {"finish_reason", nullptr},
                 {"index", 0},
                 {"delta",
-                json{
+                json {
                     {"content", content},
                 }},
             }});
@@ -955,8 +964,11 @@ struct server_slot {
 
     size_t last_nl_pos = 0;
 
-    std::string generated_text;
+    std::string  generated_text;
+    llama_tokens generated_tokens;
+
     llama_tokens cache_tokens;
+
     std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
@@ -1000,6 +1012,7 @@ struct server_slot {
         n_sent_token_probs = 0;
         task_type          = SERVER_TASK_TYPE_COMPLETION;
 
+        generated_tokens.clear();
         generated_token_probs.clear();
     }
 
@@ -1740,8 +1753,10 @@ struct server_context {
         const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;
 
-        // search stop word and delete it
         slot.generated_text += token_str;
+        if (slot.params.return_tokens) {
+            slot.generated_tokens.push_back(result.tok);
+        }
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
@@ -1766,6 +1781,7 @@ struct server_context {
             break;
         }
 
+        // search stop word and delete it
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
@@ -1918,6 +1934,7 @@ struct server_context {
         res->id      = slot.id_task;
         res->index   = slot.index;
         res->content = tkn.text_to_send;
+        res->tokens  = { tkn.tok };
 
         res->n_decoded       = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
@@ -1958,6 +1975,7 @@ struct server_context {
 
         res->index           = slot.index;
         res->content         = slot.generated_text;
+        res->tokens          = slot.generated_tokens;
         res->timings         = slot.get_timings();
         res->prompt          = common_detokenize(ctx, slot.prompt_tokens, true);
 
index 062ebcd4a05cce2508a42cbdd7b2fe1187ecf261..36aee57dd363866c955a63672f21b1d7e938b262 100644 (file)
@@ -10,16 +10,17 @@ def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
-@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
-    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
-def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": n_predict,
         "prompt": prompt,
+        "return_tokens": return_tokens,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == n_prompt
@@ -27,6 +28,11 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
     assert res.body["truncated"] == truncated
     assert type(res.body["has_new_line"]) == bool
     assert match_regex(re_content, res.body["content"])
+    if return_tokens:
+        assert len(res.body["tokens"]) > 0
+        assert all(type(tok) == int for tok in res.body["tokens"])
+    else:
+        assert res.body["tokens"] == []
 
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -56,6 +62,8 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
             assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
+            assert len(data["tokens"]) > 0
+            assert all(type(tok) == int for tok in data["tokens"])
             content += data["content"]