]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix logprobs, make it OAI-compatible (#10783)
authorXuan Son Nguyen <redacted>
Thu, 19 Dec 2024 14:40:08 +0000 (15:40 +0100)
committerGitHub <redacted>
Thu, 19 Dec 2024 14:40:08 +0000 (15:40 +0100)
* server : fix logprobs, make it openai-compatible

* update docs

* add std::log

* return pre-sampling p

* sort before apply softmax

* add comment

* fix test

* set p for sampled token

* update docs

* add --multi-token-probs

* update docs

* add `post_sampling_probs` option

* update docs [no ci]

* remove --multi-token-probs

* "top_probs" with "post_sampling_probs"

* resolve review comments

* rename struct token_prob to prob_info

* correct comment placement

* fix setting prob for sampled token

examples/server/README.md
examples/server/server.cpp
examples/server/tests/unit/test_chat_completion.py
examples/server/tests/unit/test_completion.py
examples/server/tests/unit/test_embedding.py
examples/server/utils.hpp

index d006a8d37cf6fe02cf9c594d83f591ddd906e595..6d64656926250872ec62b3f15b045190d735713a 100644 (file)
@@ -343,6 +343,10 @@ node index.js
 
 ### POST `/completion`: Given a `prompt`, it returns the predicted completion.
 
+> [!IMPORTANT]
+>
+> This endpoint is **not** OAI-compatible
+
 *Options:*
 
 `prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true:
@@ -444,38 +448,68 @@ These words will not be included in the completion, so make sure to add them to
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response.  Default: `false`
 
+`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
 
-- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure:
-
-```json
-{
-  "content": "<the token generated by the model>",
-  "tokens": [ generated token ids if requested ],
-  "probs": [
-    {
-      "prob": float,
-      "tok_str": "<most likely token>"
-    },
-    {
-      "prob": float,
-      "tok_str": "<second most likely token>"
-    },
+- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
+  ```json
+  {
+    "content": "<the generated completion text>",
+    "tokens": [ generated token ids if requested ],
     ...
-  ]
-},
-```
-
-Notice that each `probs` is an array of length `n_probs`.
+    "probs": [
+      {
+        "id": <token id>,
+        "logprob": float,
+        "token": "<most likely token>",
+        "bytes": [int, int, ...],
+        "top_logprobs": [
+          {
+            "id": <token id>,
+            "logprob": float,
+            "token": "<token text>",
+            "bytes": [int, int, ...],
+          },
+          {
+            "id": <token id>,
+            "logprob": float,
+            "token": "<token text>",
+            "bytes": [int, int, ...],
+          },
+          ...
+        ]
+      },
+      {
+        "id": <token id>,
+        "logprob": float,
+        "token": "<most likely token>",
+        "bytes": [int, int, ...],
+        "top_logprobs": [
+          ...
+        ]
+      },
+      ...
+    ]
+  },
+  ```
+  Please note that if `post_sampling_probs` is set to `true`:
+    - `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0
+    - `top_logprobs` will be replaced with `top_probs`. Each element contains:
+      - `id`: token ID
+      - `token`: token in string
+      - `bytes`: token in bytes
+      - `prob`: token probability, with the value between 0.0 and 1.0
+    - Number of elements in `top_probs` may be less than `n_probs`
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
 - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
 - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
-- `model`: The path to the model loaded with `-m`
-- `prompt`: The provided `prompt`
+- `model`: The model alias (for model path, please use `/props` endpoint)
+- `prompt`: The processed `prompt` (special tokens may be added)
 - `stop_type`: Indicating whether the completion has stopped. Possible values are:
   - `none`: Generating (not stopped)
   - `eos`: Stopped because it encountered the EOS token
index 5ed4e8d2744285c39272a2c7ec6621470e67c986..fa3682a920649ce980503aa0c3da6aefac6e489d 100644 (file)
@@ -93,6 +93,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     bool ignore_eos = false;
 
     struct common_params_sampling sampling;
@@ -151,6 +152,7 @@ struct slot_params {
             {"speculative.n_min",         speculative.n_min},
             {"speculative.p_min",         speculative.p_min},
             {"timings_per_token",         timings_per_token},
+            {"post_sampling_probs",       post_sampling_probs},
         };
     }
 };
@@ -231,6 +233,7 @@ struct server_task {
         params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
         params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
         params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
+        params.post_sampling_probs         = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
@@ -436,36 +439,67 @@ inline std::string stop_type_to_str(stop_type type) {
 
 struct completion_token_output {
     llama_token tok;
+    float prob;
     std::string text_to_send;
-    struct token_prob {
+    struct prob_info {
         llama_token tok;
-        std::string tok_str;
+        std::string txt;
         float prob;
     };
-    std::vector<token_prob> probs;
+    std::vector<prob_info> probs;
 
-    json to_json() const {
+    json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
+            std::string txt(p.txt);
+            txt.resize(validate_utf8(txt));
             probs_for_token.push_back(json {
-                {"tok_str", p.tok_str},
-                {"prob",    p.prob},
+                {"id",      p.tok},
+                {"token",   txt},
+                {"bytes",   str_to_bytes(p.txt)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
             });
         }
         return probs_for_token;
     }
 
-    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
-        for (const auto & prob : probs) {
-            const std::string tok_str = prob.text_to_send;
+        for (const auto & p : probs) {
+            std::string txt(p.text_to_send);
+            txt.resize(validate_utf8(txt));
             out.push_back(json {
-                {"content", tok_str},
-                {"probs",   prob.to_json()},
+                {"id",           p.tok},
+                {"token",        txt},
+                {"bytes",        str_to_bytes(p.text_to_send)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
+                {
+                    post_sampling_probs ? "top_probs" : "top_logprobs",
+                    p.to_json(post_sampling_probs)
+                },
             });
         }
         return out;
     }
+
+    static float logarithm(float x) {
+        // nlohmann::json converts -inf to null, so we need to prevent that
+        return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
+    }
+
+    static std::vector<unsigned char> str_to_bytes(const std::string & str) {
+        std::vector<unsigned char> bytes;
+        for (unsigned char c : str) {
+            bytes.push_back(c);
+        }
+        return bytes;
+    }
 };
 
 struct server_task_result_cmpl_final : server_task_result {
@@ -486,6 +520,7 @@ struct server_task_result_cmpl_final : server_task_result {
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
 
+    bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
 
     slot_params generation_params;
@@ -530,8 +565,8 @@ struct server_task_result_cmpl_final : server_task_result {
             {"tokens_cached",       n_tokens_cached},
             {"timings",             timings.to_json()},
         };
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!stream && !probs_output.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
         return res;
     }
@@ -542,19 +577,25 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }
 
-        json choices = json::array({json{
+        json choice = json{
             {"finish_reason", finish_reason},
             {"index", 0},
             {"message", json {
                 {"content", content},
                 {"role",    "assistant"}
             }
-        }}});
+        }};
+
+        if (!stream && probs_output.size() > 0) {
+            choice["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
+            };
+        }
 
         std::time_t t = std::time(0);
 
         json res = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"model", oaicompat_model},
             {"object", "chat.completion"},
@@ -584,12 +625,14 @@ struct server_task_result_cmpl_final : server_task_result {
             finish_reason = "stop";
         }
 
-        json choices = json::array({json{{"finish_reason", finish_reason},
-                                        {"index", 0},
-                                        {"delta", json::object()}}});
+        json choice = json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"delta", json::object()}
+        };
 
         json ret = json {
-            {"choices", choices},
+            {"choices", json::array({choice})},
             {"created", t},
             {"id",      oaicompat_cmpl_id},
             {"model",   oaicompat_model},
@@ -618,7 +661,8 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
-    std::vector<completion_token_output> probs_output;
+    bool post_sampling_probs;
+    completion_token_output prob_output;
     result_timings timings;
 
     // OAI-compat fields
@@ -655,8 +699,8 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (timings.prompt_n > 0) {
             res.push_back({"timings", timings.to_json()});
         }
-        if (!probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+        if (!prob_output.probs.empty()) {
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
         return res;
     }
@@ -708,6 +752,14 @@ struct server_task_result_cmpl_partial : server_task_result {
             }});
         }
 
+        GGML_ASSERT(choices.size() >= 1);
+
+        if (prob_output.probs.size() > 0) {
+            choices[0]["logprobs"] = json{
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+            };
+        }
+
         json ret = json {
             {"choices", choices},
             {"created", t},
@@ -1001,7 +1053,6 @@ struct server_slot {
 
     // stats
     size_t n_sent_text        = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -1023,7 +1074,6 @@ struct server_slot {
         stopping_word      = "";
         n_past             = 0;
         n_sent_text        = 0;
-        n_sent_token_probs = 0;
         task_type          = SERVER_TASK_TYPE_COMPLETION;
 
         generated_tokens.clear();
@@ -1764,7 +1814,7 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;
 
         slot.generated_text += token_str;
@@ -1774,26 +1824,7 @@ struct server_context {
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
 
         // search stop word and delete it
         if (!incomplete) {
@@ -1923,6 +1954,55 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.params.sampling.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+        if (post_sampling) {
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            // TODO: optimize this with min-p optimization
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            // set probability for sampled token
+            for (size_t i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    result.prob = cur[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(task.id, error, type);
     }
@@ -1950,8 +2030,9 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens  = { tkn.tok };
 
-        res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose           = slot.params.verbose;
         res->oaicompat         = slot.params.oaicompat;
@@ -1961,17 +2042,7 @@ struct server_context {
 
         // populate res.probs_output
         if (slot.params.sampling.n_probs > 0) {
-            const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
-
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs,                       slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                res->probs_output = std::vector<completion_token_output>(
-                        slot.generated_token_probs.begin() + probs_pos,
-                        slot.generated_token_probs.begin() + probs_stop_pos);
-            }
+            res->prob_output = tkn; // copy the token probs
         }
 
         // populate timings if this is final response or timings_per_token is enabled
@@ -1993,13 +2064,14 @@ struct server_context {
         res->timings         = slot.get_timings();
         res->prompt          = common_detokenize(ctx, slot.prompt_tokens, true);
 
-        res->truncated       = slot.truncated;
-        res->n_decoded       = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
-        res->n_tokens_cached = slot.n_past;
-        res->has_new_line    = slot.has_new_line;
-        res->stopping_word   = slot.stopping_word;
-        res->stop            = slot.stop;
+        res->truncated           = slot.truncated;
+        res->n_decoded           = slot.n_decoded;
+        res->n_prompt_tokens     = slot.n_prompt_tokens;
+        res->n_tokens_cached     = slot.n_past;
+        res->has_new_line        = slot.has_new_line;
+        res->stopping_word       = slot.stopping_word;
+        res->stop                = slot.stop;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose           = slot.params.verbose;
         res->stream            = slot.params.stream;
@@ -2796,7 +2868,9 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                const int tok_idx = slot.i_batch - i;
+
+                llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 
                 slot.i_batch = -1;
 
@@ -2815,17 +2889,12 @@ struct server_context {
                 slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
 
                 completion_token_output result;
-                result.tok = id;
+                result.tok          = id;
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob         = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-
-                for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
-                    auto tok_id = cur_p->data[i].id;
-                    result.probs.push_back({
-                        tok_id,
-                        tokens_to_output_formatted_string(ctx, tok_id),
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
+                if (slot.params.sampling.n_probs > 0) {
+                    populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
                 }
 
                 if (!process_token(result, slot)) {
@@ -2909,7 +2978,11 @@ struct server_context {
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
 
-                    result.tok = ids[i];
+                    result.tok          = ids[i];
+                    result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                    result.prob         = 1.0f; // set later
+
+                    // TODO: set result.probs
 
                     if (!process_token(result, slot)) {
                         // release slot because of stop condition
index 6573cc17f7b87843bdc389ee2e0f95065fcf1e43..0fa1a17c1f50a1ed48b0903662abc0ef672ffd31 100644 (file)
@@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library():
         seed=42,
         temperature=0.8,
     )
-    print(res)
     assert res.choices[0].finish_reason == "length"
     assert res.choices[0].message.content is not None
     assert match_regex("(Suddenly)+", res.choices[0].message.content)
@@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token():
         assert "predicted_per_second" in data["timings"]
         assert "predicted_n" in data["timings"]
         assert data["timings"]["predicted_n"] <= 10
+
+
+def test_logprobs():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+    )
+    output_text = res.choices[0].message.content
+    aggregated_text = ''
+    assert res.choices[0].logprobs is not None
+    assert res.choices[0].logprobs.content is not None
+    for token in res.choices[0].logprobs.content:
+        aggregated_text += token.token
+        assert token.logprob <= 0.0
+        assert token.bytes is not None
+        assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text
+
+
+def test_logprobs_stream():
+    global server
+    server.start()
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=5,
+        logprobs=True,
+        top_logprobs=10,
+        stream=True,
+    )
+    output_text = ''
+    aggregated_text = ''
+    for data in res:
+        choice = data.choices[0]
+        if choice.finish_reason is None:
+            if choice.delta.content:
+                output_text += choice.delta.content
+            assert choice.logprobs is not None
+            assert choice.logprobs.content is not None
+            for token in choice.logprobs.content:
+                aggregated_text += token.token
+                assert token.logprob <= 0.0
+                assert token.bytes is not None
+                assert token.top_logprobs is not None
+                assert len(token.top_logprobs) > 0
+    assert aggregated_text == output_text
index 36aee57dd363866c955a63672f21b1d7e938b262..b88d45f18547ff8240311f165b79f5713af0b662 100644 (file)
@@ -270,9 +270,68 @@ def test_n_probs():
     assert "completion_probabilities" in res.body
     assert len(res.body["completion_probabilities"]) == 5
     for tok in res.body["completion_probabilities"]:
-        assert "probs" in tok
-        assert len(tok["probs"]) == 10
-        for prob in tok["probs"]:
-            assert "prob" in prob
-            assert "tok_str" in prob
-            assert 0.0 <= prob["prob"] <= 1.0
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "logprob" in tok and tok["logprob"] <= 0.0
+        assert "bytes" in tok and type(tok["bytes"]) == list
+        assert len(tok["top_logprobs"]) == 10
+        for prob in tok["top_logprobs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "logprob" in prob and prob["logprob"] <= 0.0
+            assert "bytes" in prob and type(prob["bytes"]) == list
+
+
+def test_n_probs_stream():
+    global server
+    server.start()
+    res = server.make_stream_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "stream": True,
+    })
+    for data in res:
+        if data["stop"] == False:
+            assert "completion_probabilities" in data
+            assert len(data["completion_probabilities"]) == 1
+            for tok in data["completion_probabilities"]:
+                assert "id" in tok and tok["id"] > 0
+                assert "token" in tok and type(tok["token"]) == str
+                assert "logprob" in tok and tok["logprob"] <= 0.0
+                assert "bytes" in tok and type(tok["bytes"]) == list
+                assert len(tok["top_logprobs"]) == 10
+                for prob in tok["top_logprobs"]:
+                    assert "id" in prob and prob["id"] > 0
+                    assert "token" in prob and type(prob["token"]) == str
+                    assert "logprob" in prob and prob["logprob"] <= 0.0
+                    assert "bytes" in prob and type(prob["bytes"]) == list
+
+
+def test_n_probs_post_sampling():
+    global server
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "post_sampling_probs": True,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
+        assert "bytes" in tok and type(tok["bytes"]) == list
+        assert len(tok["top_probs"]) == 10
+        for prob in tok["top_probs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
+            assert "bytes" in prob and type(prob["bytes"]) == list
+        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
+        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
index e32d7458296051fe4be0ade439167232f82e10a1..43e372fc70d71a3c23d8bc99eb9b4a3ea192fc73 100644 (file)
@@ -50,6 +50,8 @@ def test_embedding_multiple():
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
+        # do not crash on empty input
+        ("", False),
         # single prompt
         ("string", False),
         ([12, 34, 56], False),
@@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai():
 
     # /v1/embeddings does not support pooling type 'none'
     assert res.status_code == 400
+    assert "error" in res.body
 
 
 def test_embedding_openai_library_single():
index ffdffe904308f98d11024f60a3f78376de55fdf4..94bb285b6f2d14847f4b36847223e6e712a2d14e 100644 (file)
@@ -171,6 +171,36 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     return result;
 }
 
+// return the last index of character that can form a valid string
+// if the last character is potentially cut in half, return the index before the cut
+// if validate_utf8(text) == text.size(), then the whole text is valid utf8
+static size_t validate_utf8(const std::string& text) {
+    size_t len = text.size();
+    if (len == 0) return 0;
+
+    // Check the last few bytes to see if a multi-byte character is cut off
+    for (size_t i = 1; i <= 4 && i <= len; ++i) {
+        unsigned char c = text[len - i];
+        // Check for start of a multi-byte sequence from the end
+        if ((c & 0xE0) == 0xC0) {
+            // 2-byte character start: 110xxxxx
+            // Needs at least 2 bytes
+            if (i < 2) return len - i;
+        } else if ((c & 0xF0) == 0xE0) {
+            // 3-byte character start: 1110xxxx
+            // Needs at least 3 bytes
+            if (i < 3) return len - i;
+        } else if ((c & 0xF8) == 0xF0) {
+            // 4-byte character start: 11110xxx
+            // Needs at least 4 bytes
+            if (i < 4) return len - i;
+        }
+    }
+
+    // If no cut-off multi-byte character is found, return full length
+    return len;
+}
+
 //
 // template utils
 //
@@ -671,3 +701,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
 static std::string safe_json_to_str(json data) {
     return data.dump(-1, ' ', false, json::error_handler_t::replace);
 }
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // sort tokens by logits
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.logit > b.logit;
+    });
+
+    // apply softmax
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    return cur;
+}