]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add SPM infill support (#8016)
authorSigbjørn Skjæret <redacted>
Fri, 28 Jun 2024 10:53:43 +0000 (12:53 +0200)
committerGitHub <redacted>
Fri, 28 Jun 2024 10:53:43 +0000 (12:53 +0200)
* add --spm-infill option

* support --spm-infill

* support --spm-infill

common/common.cpp
common/common.h
examples/infill/README.md
examples/infill/infill.cpp
examples/server/README.md
examples/server/server.cpp

index 57d03a5789edd614c5fa66f6207704b7b9887397..6a00d25be1316a1128a5dfc06714c01d215d407e 100644 (file)
@@ -1026,6 +1026,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.input_suffix = argv[i];
         return true;
     }
+    if (arg == "--spm-infill") {
+        params.spm_infill = true;
+        return true;
+    }
     if (arg == "--grammar") {
         CHECK_ARG
         sparams.grammar = argv[i];
@@ -1409,6 +1413,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "main infill", "       --in-prefix-bos",        "prefix BOS to user inputs, preceding the `--in-prefix` string" });
     options.push_back({ "main infill", "       --in-prefix STRING",     "string to prefix user inputs with (default: empty)" });
     options.push_back({ "main infill", "       --in-suffix STRING",     "string to suffix after user inputs with (default: empty)" });
+    options.push_back({ "server infill",
+                                       "       --spm-infill",           "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });
 
     options.push_back({ "sampling" });
     options.push_back({ "*",           "       --samplers SAMPLERS",    "samplers that will be used for generation in the order, separated by \';\'\n"
index 0486ba3800ed754f95b5b89a393ae95dcdeb7a10..d6cb814b990e9ba8049377a11dcb26e85f194c16 100644 (file)
@@ -250,6 +250,8 @@ struct gpt_params {
     std::string cvector_outfile       = "control_vector.gguf";
     std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
     std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
+
+    bool spm_infill = false; // suffix/prefix/middle pattern for infill
 };
 
 void gpt_params_handle_model_default(gpt_params & params);
index 74f42d2fc26965cfc3cc525903563db563358681..810a0c5e76697ae4bafbfa7869f923eb38416882 100644 (file)
@@ -15,6 +15,7 @@ In this section, we cover the most commonly used options for running the `infill
 -   `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses.
 -   `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text.
 -   `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
+-   `--spm-infill`: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
 
 ## Input Prompts
 
index 3e82e4a81a20bca745b004264eaace04a2283af3..ca71dd687f30e00f9895e3f5de8c741647a70d53 100644 (file)
@@ -210,6 +210,7 @@ int main(int argc, char ** argv) {
         suff_rm_leading_spc = false;
     }
     std::vector<llama_token> embd_inp;
+    std::vector<llama_token> embd_end;
     std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
     std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
     const int space_token = 29871;
@@ -217,12 +218,13 @@ int main(int argc, char ** argv) {
         inp_sfx.erase(inp_sfx.begin());
     }
     inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
+    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
+    embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
+    embd_end = params.spm_infill ? inp_pfx : inp_sfx;
     if (add_bos) {
-        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
+        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
     }
-    inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
-    embd_inp = inp_pfx;
-    embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
 
     const llama_token middle_token = llama_token_middle(model);
     if (middle_token >= 0) {
@@ -526,14 +528,14 @@ int main(int argc, char ** argv) {
                     inp_sfx.erase(inp_sfx.begin());
                 }
                 inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
+                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
+                embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
+                embd_end = params.spm_infill ? inp_pfx : inp_sfx;
                 if (add_bos) {
-                    inp_pfx.insert(inp_pfx.begin(), llama_token_bos(model));
+                    embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                 }
-                inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
-                embd_inp = inp_pfx;
-                embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+                embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
 
-                const llama_token middle_token = llama_token_middle(model);
                 if (middle_token >= 0) {
                     embd_inp.push_back(middle_token);
                 }
index e7fb0bf64c0e11f2a8067eae5fca95724c5dd2bd..4fab006bb9a43520f1785ababe348403953c520e 100644 (file)
@@ -73,6 +73,7 @@ The project is under active development, and we are [looking for feedback and co
 - `-fa`, `--flash-attn` : enable flash attention (default: disabled).
 - `-ctk TYPE`, `--cache-type-k TYPE` : KV cache data type for K (default: `f16`, options `f32`, `f16`, `q8_0`, `q4_0`, `q4_1`, `iq4_nl`, `q5_0`, or `q5_1`)
 - `-ctv TYPE`, `--cache-type-v TYPE` : KV cache type for V (default `f16`, see `-ctk` for options)
+- `--spm-infill` : Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
 
 **If compiled with `LLAMA_SERVER_SSL=ON`**
 - `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key
index ae768097baa0e5a3f5941f1eaa0e1830a7acd49d..d7fb61812aa3e73b8a86db8299e36c8f6dcd43d4 100644 (file)
@@ -2020,6 +2020,7 @@ struct server_context {
                         slot.t_start_generation = 0;
 
                         if (slot.infill) {
+                            const bool add_bos = llama_should_add_bos_token(model);
                             bool suff_rm_leading_spc = true;
                             if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                                 params.input_suffix.erase(0, 1);
@@ -2035,16 +2036,21 @@ struct server_context {
                             }
 
                             prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                            prefix_tokens.insert(prefix_tokens.end(),   llama_token_suffix(model));
-                            prefix_tokens.insert(prefix_tokens.end(),   suffix_tokens.begin(), suffix_tokens.end());
+                            suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
+
+                            auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
+                            auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+                            if (add_bos) {
+                                embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                            }
+                            embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
 
                             const llama_token middle_token = llama_token_middle(model);
                             if (middle_token >= 0) {
-                                prefix_tokens.push_back(middle_token);
+                                embd_inp.push_back(middle_token);
                             }
 
-                            prompt_tokens = prefix_tokens;
+                            prompt_tokens = embd_inp;
                         } else {
                             prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                         }