]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add option to time limit the generation phase (#9865)
authorGeorgi Gerganov <redacted>
Sat, 12 Oct 2024 13:14:27 +0000 (16:14 +0300)
committerGitHub <redacted>
Sat, 12 Oct 2024 13:14:27 +0000 (16:14 +0300)
ggml-ci

examples/server/README.md
examples/server/server.cpp

index caffbac52730608b5575580f80e02b3e278a0bf1..b5feeb77bd028670da443f9175e51f6511dfa8c9 100644 (file)
@@ -374,6 +374,8 @@ node index.js
 
     `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`
 
+    `t_max_predict_ms`: Set a time limit in milliseconds for the prediction (a.k.a. text-generation) phase. The timeout will trigger if the generation takes more than the specified time (measured since the first token was generated) and if a new-line character has already been generated. Useful for FIM applications. Default: `0`, which is disabled.
+
     `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
     `id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot.  Default: `-1`
index 0dd2fc8b204cc06a0ae7fbc3bf2c2fe531f8ebec..f809c46d5a3084a0be3bcd52d62353a265a5abb4 100644 (file)
@@ -128,9 +128,12 @@ struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
 
-    int32_t  n_keep    =  0; // number of tokens to keep from initial prompt
-    int32_t  n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t  n_predict = -1; // new tokens to predict
+    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
+    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_predict = -1; // new tokens to predict
+
+    int64_t t_max_prompt_ms  = -1; // TODO: implement
+    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
 
@@ -175,6 +178,7 @@ struct server_slot {
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
 
     bool has_next_token = true;
+    bool has_new_line   = false;
     bool truncated      = false;
     bool stopped_eos    = false;
     bool stopped_word   = false;
@@ -210,6 +214,7 @@ struct server_slot {
 
         n_prompt_tokens    = 0;
         generated_text     = "";
+        has_new_line       = false;
         truncated          = false;
         stopped_eos        = false;
         stopped_word       = false;
@@ -874,6 +879,8 @@ struct server_context {
         slot.sparams.seed              = json_value(data, "seed",              default_sparams.seed);
         slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
         slot.sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
+      //slot.params.t_max_prompt_ms    = json_value(data, "t_max_prompt_ms",   default_params.t_max_prompt_ms); // TODO: implement
+        slot.params.t_max_predict_ms   = json_value(data, "t_max_predict_ms",  default_params.t_max_predict_ms);
 
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -1101,6 +1108,20 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }
 
+        // if we have already seen a new line, we stop after a certain time limit
+        if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
+            (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+            slot.stopped_limit  = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+        }
+
+        // check if there is a new line in the generated text
+        if (result.text_to_send.find('\n') != std::string::npos) {
+            slot.has_new_line = true;
+        }
+
         // if context shift is disabled, we stop when it reaches the context limit
         if (slot.n_past >= slot.n_ctx) {
             slot.truncated      = true;
@@ -1250,6 +1271,7 @@ struct server_context {
             {"tokens_evaluated",    slot.n_prompt_tokens},
             {"generation_settings", get_formated_generation(slot)},
             {"prompt",              slot.prompt},
+            {"has_new_line",        slot.has_new_line},
             {"truncated",           slot.truncated},
             {"stopped_eos",         slot.stopped_eos},
             {"stopped_word",        slot.stopped_word},
@@ -1576,6 +1598,7 @@ struct server_context {
                         slot_data["prompt"]     = slot.prompt;
                         slot_data["next_token"] = {
                             {"has_next_token", slot.has_next_token},
+                            {"has_new_line",   slot.has_new_line},
                             {"n_remain",       slot.n_remaining},
                             {"n_decoded",      slot.n_decoded},
                             {"stopped_eos",    slot.stopped_eos},
@@ -1914,6 +1937,13 @@ struct server_context {
                                     auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                                     auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
 
+                                    // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                                    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+
+                                    prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
+                                    suffix_tokens.resize(n_suffix_take);
+
                                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
                                     suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
 
@@ -1936,9 +1966,17 @@ struct server_context {
 
                         SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
-                        // print prompt tokens:
-                        for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-                            SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                        // print prompt tokens (for debugging)
+                        if (1) {
+                            // first 16 tokens (avoid flooding logs)
+                            for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            }
+                        } else {
+                            // all
+                            for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                                SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                            }
                         }
 
                         // empty prompt passed -> release the slot and send empty response