]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : add infill sampler (#9896)
authorGeorgi Gerganov <redacted>
Tue, 15 Oct 2024 13:35:33 +0000 (16:35 +0300)
committerGitHub <redacted>
Tue, 15 Oct 2024 13:35:33 +0000 (16:35 +0300)
ggml-ci

common/common.h
common/sampling.cpp
examples/main/main.cpp
include/llama.h
src/llama-sampling.cpp
src/llama-sampling.h
src/llama-vocab.cpp
src/llama-vocab.h
src/llama.cpp

index df2ee6bd43a3312cc74f94ab82896e833b44efe1..5ca8fd391ab7425f7064cd3ee3093c6af8df66e2 100644 (file)
@@ -91,7 +91,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TYPICAL_P   = 5,
     COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
     COMMON_SAMPLER_TYPE_XTC         = 7,
-
+    COMMON_SAMPLER_TYPE_INFILL      = 8,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -136,7 +136,7 @@ struct common_sampler_params {
         COMMON_SAMPLER_TYPE_TOP_P,
         COMMON_SAMPLER_TYPE_MIN_P,
         COMMON_SAMPLER_TYPE_XTC,
-        COMMON_SAMPLER_TYPE_TEMPERATURE
+        COMMON_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
index fb95bcd3bf2b098187c59fe8caa8a5c5c920cefb..56cd0df6b81bc8adae41a1e1be13eb2ed2e24e6f 100644 (file)
@@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     case COMMON_SAMPLER_TYPE_TEMPERATURE:
                         llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                         break;
+                    case COMMON_SAMPLER_TYPE_INFILL:
+                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
+                        break;
                     default:
                         GGML_ASSERT(false && "unknown sampler type");
                 }
@@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_MIN_P:       return 'm';
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
+        case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
@@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_MIN_P:       return "min_p";
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
+        case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
@@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "tfs_z",       COMMON_SAMPLER_TYPE_TFS_Z },
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
+        { "infill",      COMMON_SAMPLER_TYPE_INFILL },
     };
 
     // since samplers names are written multiple ways
@@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P),       COMMON_SAMPLER_TYPE_TOP_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P),       COMMON_SAMPLER_TYPE_MIN_P },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
-        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC }
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
     };
 
     std::vector<common_sampler_type> samplers;
index fb10c20c5e36d32b2da17d63378a8e0738ae992a..65483c45f3f92fc1e1312ad31f1372dd1ac8daac 100644 (file)
@@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
                     if (!params.ctx_shift){
                         LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                         break;
-                    } else {
-                        if (params.n_predict == -2) {
-                            LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
-                            break;
-                        }
+                    }
+
+                    if (params.n_predict == -2) {
+                        LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                        break;
+                    }
 
-                        const int n_left    = n_past - params.n_keep;
-                        const int n_discard = n_left/2;
+                    const int n_left    = n_past - params.n_keep;
+                    const int n_discard = n_left/2;
 
-                        LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                                n_past, n_left, n_ctx, params.n_keep, n_discard);
+                    LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                            n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                        llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                        llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                        n_past -= n_discard;
+                    n_past -= n_discard;
 
-                        LOG_DBG("after swap: n_past = %d\n", n_past);
+                    LOG_DBG("after swap: n_past = %d\n", n_past);
 
-                        LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+                    LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
 
-                        LOG_DBG("clear session path\n");
-                        path_session.clear();
-                    }
+                    LOG_DBG("clear session path\n");
+                    path_session.clear();
                 }
             } else {
                 // context extension via Self-Extend
index 92d4c70c13b876ba2253750900ee4ffe6fae877e..02bc7f087c62b71904865f83599114e0a05c460a 100644 (file)
@@ -953,6 +953,12 @@ extern "C" {
                                int32_t   lstrip,
                                   bool   special);
 
+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+              const struct llama_model * model,
+                           llama_token   token0,
+                           llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1148,6 +1154,28 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k + top_p sampling
+    //
+    // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+    // 2. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    // 3. discard non-EOG tokens with low prob
+    // 4. if no tokens are left -> pick EOT
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
 
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
index 67a78c3ac4fe8eba7ca20a0c4d7c78d609571723..2e655068272b8adaa1fc21c4039baf2fbf6d393d 100644 (file)
@@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p >  cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                    cur_p->data[j].p     = 0.0f;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                    cur_p->data[i].p     = 0.0f;
+                }
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+        },
+    };
+}
+
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
index d90b147130e4b10d691d797dc99137569ca22a6e..2683f1b92696f1a7c30a88282f327a07e145cf08 100644 (file)
@@ -4,8 +4,6 @@
 
 #include "llama-grammar.h"
 
-#include <unordered_map>
-
 struct llama_vocab;
 struct llama_grammar;
 
@@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,
                       const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab);
index a27394a3772314a2bcf4b765cdf6415b55d8aef1..070de936536e0d1bf1ca4c3c27a667cd3cbc766d 100644 (file)
@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token0,
+                     llama_token   token1) {
+    char text_buf_0[128];
+    char text_buf_1[128];
+
+    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
+    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
+
+    if (len0 <= 0 || len1 <= 0) {
+        return false;
+    }
+
+    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
                const llama_token * tokens,
index 17e14488a4d52688c83eba1212fa4b3b28c8f904..d958d0073be95699b5c5b0e3c85e4d6ce1a01154 100644 (file)
@@ -48,7 +48,7 @@ struct llama_vocab {
     id special_cls_id  = LLAMA_TOKEN_NULL;
     id special_mask_id = LLAMA_TOKEN_NULL;
 
-    id linefeed_id    = 13;
+    id linefeed_id = 13;
 
     // fim tokens
     id special_fim_pre_id = LLAMA_TOKEN_NULL;
@@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
                          int32_t   lstrip,
                             bool   special);
 
+// check if token0 is contained as a prefix in token1
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token0,
+                     llama_token   token1);
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
                const llama_token * tokens,
index 511f91802d939065235c798efcb6217e4e690e40..8d44c73c8c95c1dce612e25f350043dac6ecc19a 100644 (file)
@@ -21500,6 +21500,13 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 
+bool llama_token_is_prefix(
+    const struct llama_model * model,
+                 llama_token   token0,
+                 llama_token   token1) {
+    return llama_token_is_prefix_impl(model->vocab, token0, token1);
+}
+
 int32_t llama_detokenize(
     const struct llama_model * model,
            const llama_token * tokens,
@@ -21830,6 +21837,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
+    return llama_sampler_init_infill_impl(model->vocab);
+}
+
 //
 // model split
 //