]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : accept extra_context for the infill endpoint (#9874)
authorGeorgi Gerganov <redacted>
Sun, 13 Oct 2024 18:31:35 +0000 (21:31 +0300)
committerGitHub <redacted>
Sun, 13 Oct 2024 18:31:35 +0000 (21:31 +0300)
* server : accept extra_context for the infill endpoint

ggml-ci

* server : update readme [no ci]

* server : use repo-level FIM pattern if possible

ggml-ci

examples/server/README.md
examples/server/server.cpp
src/llama.cpp

index cd0eaf847f769113de72f07c1e9cf2e6cbbdd758..eb0a7b32ef8890dc644f75f192fcec21bc2b31f4 100644 (file)
@@ -524,9 +524,30 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
 
 - `input_prefix`: Set the prefix of the code to infill.
 - `input_suffix`: Set the suffix of the code to infill.
+- `prompt`: Added after the `FIM_MID` token
+- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874
 
 It also accepts all the options of `/completion`.
 
+If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
+
+```txt
+<FIM_REP>myproject
+<FIM_SEP>{chunk 0 filename}
+{chunk 0 text}
+<FIM_SEP>{chunk 1 filename}
+{chunk 1 text}
+...
+<FIM_SEP>filename
+<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
+```
+
+If the tokens are missing, then the extra context is simply prefixed at the start:
+
+```txt
+[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
+```
+
 ### **GET** `/props`: Get server global properties.
 
 This endpoint is public (no API key check). By default, it is read-only. To make POST request to change global properties, you need to start server with `--props`
index 015b3b2c56a31569e2c3a6fb48ea2b659b1b0a11..18bcad3f06bca0c418a5c469db424db4483759a2 100644 (file)
@@ -139,6 +139,7 @@ struct slot_params {
 
     json input_prefix;
     json input_suffix;
+    json extra_context;
 };
 
 struct server_slot {
@@ -170,6 +171,7 @@ struct server_slot {
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;
 
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
@@ -906,8 +908,26 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk["text"].is_string()) {
+                send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            // filename is optional
+            if (chunk.contains("filename") && !chunk["filename"].is_string()) {
+                send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
+        }
 
         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
@@ -1934,13 +1954,66 @@ struct server_context {
                                 } break;
                             case SERVER_TASK_CMPL_TYPE_INFILL:
                                 {
+                                    // use FIM repo-level pattern:
+                                    // ref: https://arxiv.org/pdf/2409.12186
+                                    //
+                                    // [FIM_REP]myproject
+                                    // [FIM_SEP]filename0
+                                    // extra chunk 0
+                                    // [FIM_SEP]filename1
+                                    // extra chunk 1
+                                    // ...
+                                    // [FIM_SEP]filename
+                                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                                    //
                                     auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                                     auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
 
-                                    // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                                    slot.extra_tokens.clear();
+                                    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+                                        static const auto k_fim_repo = tokenize("myproject\n", false, false);
+
+                                        slot.extra_tokens.push_back(llama_token_fim_rep(model));
+                                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+                                    }
+
+                                    for (const auto & chunk : slot.params.extra_context) {
+                                        // { "text": string, "filename": string }
+                                        const std::string text     = chunk.value("text", "");
+                                        const std::string filename = chunk.value("filename", "tmp");
+
+                                        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                                            const auto k_fim_file = tokenize(filename + "\n", false, false);
+
+                                            slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                                        } else {
+                                            // chunk separator in binary form to avoid confusing the AI
+                                            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                                            static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+                                        }
+
+                                        const auto chunk_tokens = tokenize(text, false, false);
+                                        slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+                                    }
+
+                                    if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                                        // TODO: current filename
+                                        static const auto k_fim_file = tokenize("filename\n", false, false);
+
+                                        slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                                    }
+
+                                    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                                     const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
 
+                                    // fill the rest of the context with extra chunks
+                                    const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
+
                                     prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                                     suffix_tokens.resize(n_suffix_take);
 
@@ -1954,6 +2027,11 @@ struct server_context {
                                         embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                                     }
 
+                                    SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                                    // put the extra context before the FIM prefix
+                                    embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                                     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                                     embd_inp.push_back(llama_token_fim_mid(model));
 
@@ -2058,11 +2136,15 @@ struct server_context {
 
                                     while (head_c < slot.cache_tokens.size() &&
                                            head_p < prompt_tokens.size()) {
-                                        if (llama_token_is_control(model, slot.cache_tokens[head_c])) {
+                                        if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                                            slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
+                                            slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
                                             break;
                                         }
 
-                                        if (llama_token_is_control(model, prompt_tokens[head_p])) {
+                                        if (llama_token_is_control(model, prompt_tokens[head_p]) &&
+                                            prompt_tokens[head_p] != llama_token_fim_rep(model) &&
+                                            prompt_tokens[head_p] != llama_token_fim_sep(model)) {
                                             break;
                                         }
 
@@ -2071,11 +2153,15 @@ struct server_context {
                                         while (head_c + n_match < slot.cache_tokens.size() &&
                                                head_p + n_match < prompt_tokens.size()     &&
                                                slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                                            if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match])) {
+                                            if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                                                slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
+                                                slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
                                                 break;
                                             }
 
-                                            if (llama_token_is_control(model, prompt_tokens[head_p + n_match])) {
+                                            if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                                                prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
+                                                prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
                                                 break;
                                             }
 
index f68024f5bd2b7a03469bd8489b1dc898e99c2cc6..511f91802d939065235c798efcb6217e4e690e40 100644 (file)
@@ -6596,8 +6596,8 @@ static void llm_load_vocab(
                    ) {
                     vocab.special_eot_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6610,8 +6610,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_eom_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6627,8 +6627,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_fim_pre_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6644,8 +6644,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_fim_suf_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6661,8 +6661,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_fim_mid_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6677,8 +6677,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_fim_pad_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6694,8 +6694,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_fim_rep_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6708,8 +6708,8 @@ static void llm_load_vocab(
                         ) {
                     vocab.special_fim_sep_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
@@ -6720,6 +6720,19 @@ static void llm_load_vocab(
         // this is currently determined based on the token text, which is obviously not ideal
         // ref: https://github.com/ggerganov/llama.cpp/issues/9606
         vocab.special_eog_ids.clear();
+
+        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
+        }
+
+        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
+        }
+
+        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
+        }
+
         for (const auto & t : vocab.token_to_id) {
             if (false
                     || t.first == "<|eot_id|>"
@@ -6732,13 +6745,20 @@ static void llm_load_vocab(
                ) {
                 vocab.special_eog_ids.insert(t.second);
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
+            } else {
+                // token is control, but not marked as EOG -> print a warning
+                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_WARN("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
             }
         }
 
+        // sanity checks
         if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
             vocab.special_eog_ids.insert(vocab.special_eos_id);
             LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);