]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server: add option to output probabilities for completion (#1962)
authorWangHaoranRobin <redacted>
Sun, 2 Jul 2023 21:38:44 +0000 (05:38 +0800)
committerGitHub <redacted>
Sun, 2 Jul 2023 21:38:44 +0000 (00:38 +0300)
* server: add option to output probabilities for completion
* server: fix issue when handling probability output for incomplete tokens for multibyte character generation
* server: fix llama_sample_top_k order
* examples/common.h: put all bool variables in gpt_params together

examples/common.h
examples/server/server.cpp

index 66e5672917996e6cd5f8c9ac69fdaf90e8984a78..96f2228f8677b65e4aa3a39898059045e5993c21 100644 (file)
@@ -31,7 +31,7 @@ struct gpt_params {
     int32_t n_gpu_layers                    = 0;   // number of layers to store in VRAM
     int32_t main_gpu                        = 0;   // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
-    bool    low_vram                        = 0;   // if true, reduce VRAM usage at the cost of performance
+    int32_t n_probs                         = 0;   // if greater than 0, output the probabilities of top n_probs tokens.
 
     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
@@ -59,6 +59,7 @@ struct gpt_params {
     std::string lora_adapter = "";  // lora adapter path
     std::string lora_base    = "";  // base model path for the lora adapter
 
+    bool low_vram          = false;   // if true, reduce VRAM usage at the cost of performance
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
index 998d55eacff793d5a839fe5f5b4dc62b90830221..e4ddbe9865506a1a8504fbedfcab44603f49e1e5 100644 (file)
@@ -26,6 +26,17 @@ struct server_params {
     int32_t write_timeout = 600;
 };
 
+// completion token output with probabilities
+struct completion_token_output {
+    struct token_prob {
+        llama_token tok;
+        float prob;
+    };
+
+    std::vector<token_prob> probs;
+    llama_token tok;
+};
+
 static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line,
     fflush(stdout);
 }
 
+// format incomplete utf-8 multibyte character for output
+static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
+    std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
+    // if first bit is 1, meaning it's a partial character
+    if (out.size() > 0 && (out[0] & 0x80) == 0x80) {
+        std::stringstream ss;
+        ss<< std::hex << (out[0] & 0xff);
+        std::string res ( ss.str() );
+        out = "byte: \\x" + res;
+    }
+    return out;
+}
+
+// convert a vector of completion_token_output to json
+static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) {
+    json out = json::array();
+    for (const auto & prob : probs) {
+        json probs_for_token = json::array();
+        for (const auto & p : prob.probs) {
+            std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+            probs_for_token.push_back(json {
+                { "tok_str", tok_str },
+                { "prob", p.prob },
+            });
+        }
+        std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
+        out.push_back(json {
+            {"content", tok_str},
+            {"probs", probs_for_token},
+        });
+    }
+    return out;
+}
+
 static bool server_verbose = false;
 
 #if SERVER_VERBOSE != 1
@@ -107,6 +152,7 @@ struct llama_server_context {
     bool stream = false;
     bool has_next_token = false;
     std::string generated_text;
+    std::vector<completion_token_output> generated_token_probs;
 
     size_t num_tokens_predicted = 0;
     size_t n_past = 0;
@@ -142,6 +188,7 @@ struct llama_server_context {
         num_tokens_predicted = 0;
         generated_text = "";
         generated_text.reserve(params.n_ctx);
+        generated_token_probs.clear();
         truncated = false;
         stopped_eos = false;
         stopped_word = false;
@@ -221,8 +268,9 @@ struct llama_server_context {
         llama_set_rng_seed(ctx, params.seed);
     }
 
-    llama_token nextToken() {
-        llama_token result = -1;
+    completion_token_output nextToken() {
+        completion_token_output result;
+        result.tok = -1;
 
         if (embd.size() >= (size_t)params.n_ctx) {
             // Reset context
@@ -261,7 +309,8 @@ struct llama_server_context {
 
         if (params.n_predict == 0) {
             has_next_token = false;
-            return llama_token_eos();
+            result.tok = llama_token_eos();
+            return result;
         }
 
         // out of user input, sample next token
@@ -278,7 +327,7 @@ struct llama_server_context {
         const float mirostat_tau = params.mirostat_tau;
         const float mirostat_eta = params.mirostat_eta;
         const bool penalize_nl = params.penalize_nl;
-        llama_token id = 0;
+        const int32_t n_probs = params.n_probs;
 
         {
             auto * logits = llama_get_logits(ctx);
@@ -312,35 +361,42 @@ struct llama_server_context {
 
             if (temp <= 0) {
                 // Greedy sampling
-                id = llama_sample_token_greedy(ctx, &candidates_p);
+                result.tok = llama_sample_token_greedy(ctx, &candidates_p);
+                if (n_probs > 0) {
+                    llama_sample_softmax(ctx, &candidates_p);
+                }
             } else {
                 if (mirostat == 1) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     const int mirostat_m = 100;
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 } else if (mirostat == 2) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 } else {
                     // Temperature sampling
-                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    size_t min_keep = std::max(1, n_probs);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
                     llama_sample_temperature(ctx, &candidates_p, temp);
-                    id = llama_sample_token(ctx, &candidates_p);
+                    result.tok = llama_sample_token(ctx, &candidates_p);
                 }
             }
+
+            for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
+                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+            }
             last_n_tokens.erase(last_n_tokens.begin());
-            last_n_tokens.push_back(id);
+            last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
         }
 
         // add it to the context
-        embd.push_back(id);
-        result = id;
+        embd.push_back(result.tok);
         // decrement remaining sampling budget
         --n_remain;
 
@@ -382,12 +438,16 @@ struct llama_server_context {
         return stop_pos;
     }
 
-    std::string doCompletion() {
-        const llama_token token = nextToken();
+    completion_token_output doCompletion() {
+        const completion_token_output token_with_probs = nextToken();
 
-        const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
+        const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
         generated_text += token_text;
 
+        if (params.n_probs > 0) {
+            generated_token_probs.push_back(token_with_probs);
+        }
+
         if (multibyte_pending > 0) {
             multibyte_pending -= token_text.size();
         } else if (token_text.size() == 1) {
@@ -416,8 +476,8 @@ struct llama_server_context {
         }
 
         LOG_VERBOSE("next token", {
-            { "token", token },
-            { "token_text", llama_token_to_str(ctx, token) },
+            { "token", token_with_probs.tok },
+            { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
             { "has_next_token", has_next_token },
             { "n_remain", n_remain },
             { "num_tokens_predicted", num_tokens_predicted },
@@ -427,7 +487,7 @@ struct llama_server_context {
             { "stopping_word", stopping_word },
         });
 
-        return token_text;
+        return token_with_probs;
     }
 
     std::vector<float> getEmbedding() {
@@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) {
         { "ignore_eos", ignore_eos },
         { "stream", llama.stream },
         { "logit_bias", llama.params.logit_bias },
+        { "n_probs", llama.params.n_probs },
     };
 }
 
@@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) {
     };
 }
 
-static json format_final_response(llama_server_context & llama, const std::string & content) {
-    return json {
+static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
+
+    json res = json {
         { "content", content },
         { "stop", true },
         { "model", llama.params.model_alias },
@@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
         { "stopped_limit", llama.stopped_limit },
         { "stopping_word", llama.stopping_word },
     };
+
+    if (llama.params.n_probs > 0) {
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
+    }
+
+    return res;
 }
 
-static json format_partial_response(const std::string & content) {
-    return json {
+static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
+    json res = json {
         { "content", content },
         { "stop", false },
     };
+
+    if (llama.params.n_probs > 0) {
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
+    }
+
+    return res;
 }
 
 static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
     if (body.value("ignore_eos", false)) {
@@ -830,7 +905,8 @@ int main(int argc, char ** argv) {
             size_t stop_pos = std::string::npos;
 
             while (llama.has_next_token) {
-                const std::string token_text = llama.doCompletion();
+                const completion_token_output token_with_probs = llama.doCompletion();
+                const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
 
                 stop_pos = llama.findStoppingStrings(llama.generated_text,
                     token_text.size(), STOP_FULL);
@@ -844,7 +920,7 @@ int main(int argc, char ** argv) {
                     llama.generated_text.end());
             }
 
-            const json data = format_final_response(llama, llama.generated_text);
+            const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
 
             llama_print_timings(llama.ctx);
 
@@ -853,9 +929,11 @@ int main(int argc, char ** argv) {
         } else {
             const auto chunked_content_provider = [&](size_t, DataSink & sink) {
                 size_t sent_count = 0;
+                size_t sent_token_probs_index = 0;
 
                 while (llama.has_next_token) {
-                    const std::string token_text = llama.doCompletion();
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
                     if (llama.multibyte_pending > 0) {
                         continue;
                     }
@@ -878,10 +956,22 @@ int main(int argc, char ** argv) {
                     const std::string to_send = llama.generated_text.substr(pos, stop_pos);
                     sent_count += to_send.size();
 
+                    std::vector<completion_token_output> probs_output = {};
+
+                    if (llama.params.n_probs > 0) {
+                        const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+                        size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+                        size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+                        if (probs_pos < probs_stop_pos) {
+                            probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+                        }
+                        sent_token_probs_index = probs_stop_pos;
+                    }
+
                     const json data = llama.has_next_token
-                                          ? format_partial_response(to_send)
+                                          ? format_partial_response(llama, to_send, probs_output)
                                           // Generation is done, send extra information.
-                                          : format_final_response(llama, to_send);
+                                          : format_final_response(llama, to_send, llama.generated_token_probs);
 
                     const std::string str =
                         "data: " +