]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : add n_indent parameter for line indentation requirement (#9929)
authorGeorgi Gerganov <redacted>
Fri, 18 Oct 2024 04:32:19 +0000 (07:32 +0300)
committerGitHub <redacted>
Fri, 18 Oct 2024 04:32:19 +0000 (07:32 +0300)
ggml-ci

examples/server/README.md
examples/server/server.cpp

index fcdb02afd3b93f60b447195cdb3cabcaf38007c3..09f1aa249ab1fc84d9786eb47a67ed298503e57c 100644 (file)
@@ -333,6 +333,8 @@ node index.js
 
     `n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
 
+    `n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0`
+
     `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
     By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.
 
index b5e63384ca4f81ce9812c36b1d9ebfb6d1b38e04..8fd44387846629f26de54d62ffdca58a8de6e2c5 100644 (file)
@@ -131,6 +131,7 @@ struct slot_params {
     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
+    int32_t n_indent  =  0; // mininum line indentation for the generated text in number of whitespace characters
 
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -173,6 +174,8 @@ struct server_slot {
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
 
+    size_t last_nl_pos = 0;
+
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
@@ -215,6 +218,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens    = 0;
+        last_nl_pos        = 0;
         generated_text     = "";
         has_new_line       = false;
         truncated          = false;
@@ -860,6 +864,7 @@ struct server_context {
         slot.params.stream             = json_value(data, "stream",            false);
         slot.params.cache_prompt       = json_value(data, "cache_prompt",      false);
         slot.params.n_predict          = json_value(data, "n_predict",         json_value(data, "max_tokens", default_params.n_predict));
+        slot.params.n_indent           = json_value(data, "n_indent",          default_params.n_indent);
         slot.sparams.top_k             = json_value(data, "top_k",             default_sparams.top_k);
         slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
@@ -878,7 +883,7 @@ struct server_context {
         slot.sparams.mirostat_tau      = json_value(data, "mirostat_tau",      default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta      = json_value(data, "mirostat_eta",      default_sparams.mirostat_eta);
         slot.sparams.penalize_nl       = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
-        slot.params.n_keep             = json_value(data, "n_keep",            slot.params.n_keep);
+        slot.params.n_keep             = json_value(data, "n_keep",            default_params.n_keep);
         slot.params.n_discard          = json_value(data, "n_discard",         default_params.n_discard);
         slot.sparams.seed              = json_value(data, "seed",              default_sparams.seed);
         slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
@@ -1129,13 +1134,48 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }
 
-        // if we have already seen a new line, we stop after a certain time limit
-        if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
-            (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
-            slot.stopped_limit  = true;
-            slot.has_next_token = false;
+        if (slot.has_new_line) {
+            // if we have already seen a new line, we stop after a certain time limit
+            if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+                slot.stopped_limit  = true;
+                slot.has_next_token = false;
+
+                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+            }
+
+            // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+            if (slot.params.n_indent > 0) {
+                // check the current indentation
+                // TODO: improve by not doing it more than once for each new line
+                if (slot.last_nl_pos > 0) {
+                    size_t pos = slot.last_nl_pos;
+
+                    int n_indent = 0;
+                    while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+                        n_indent++;
+                        pos++;
+                    }
+
+                    if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+                        slot.stopped_limit  = true;
+                        slot.has_next_token = false;
+
+                        // cut the last line
+                        slot.generated_text.erase(pos, std::string::npos);
 
-            SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+                        SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
+                    }
+                }
+
+                // find the next new line
+                {
+                    const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+                    if (pos != std::string::npos) {
+                        slot.last_nl_pos = pos + 1;
+                    }
+                }
+            }
         }
 
         // check if there is a new line in the generated text