]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : infill sampling handle very long tokens (#9924)
authorGeorgi Gerganov <redacted>
Thu, 17 Oct 2024 19:32:47 +0000 (22:32 +0300)
committerGitHub <redacted>
Thu, 17 Oct 2024 19:32:47 +0000 (22:32 +0300)
* llama : infill sampling handle very long tokens

ggml-ci

* cont : better indices

ggml-ci

include/llama.h
src/llama-sampling.cpp
src/llama-vocab.cpp
src/llama.cpp

index 02bc7f087c62b71904865f83599114e0a05c460a..1a13360c21e3ae88de6e253c74131c7015969141 100644 (file)
@@ -953,12 +953,6 @@ extern "C" {
                                int32_t   lstrip,
                                   bool   special);
 
-    // check if token0 is contained as a prefix in token1
-    LLAMA_API bool llama_token_is_prefix(
-              const struct llama_model * model,
-                           llama_token   token0,
-                           llama_token   token1);
-
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
index 2e655068272b8adaa1fc21c4039baf2fbf6d393d..bd750c40ec65108197e8322533193fcbf45a23f1 100644 (file)
@@ -1745,6 +1745,9 @@ struct llama_sampler * llama_sampler_init_logit_bias(
 
 struct llama_sampler_infill {
     const struct llama_vocab * vocab;
+
+    std::vector<char> buf0;
+    std::vector<char> buf1;
 };
 
 static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
@@ -1810,27 +1813,44 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     size_t n_combined = 0; GGML_UNUSED(n_combined);
 
     // combine tokens with common prefix
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        for (size_t j = 0; j < cur_p->size; ++j) {
-            if (cur_p->data[i].logit == -INFINITY) {
+    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+            if (cur_p->data[i0].logit == -INFINITY) {
                 break;
             }
 
-            if (i == j || cur_p->data[j].logit == -INFINITY) {
+            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
                 continue;
             }
 
-            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
-                if (cur_p->data[i].p >  cur_p->data[j].p) {
-                    cur_p->data[i].p += cur_p->data[j].p;
-                    cur_p->data[j].logit = -INFINITY;
-                    cur_p->data[j].p     = 0.0f;
-                } else {
-                    cur_p->data[j].p += cur_p->data[i].p;
-                    cur_p->data[i].logit = -INFINITY;
-                    cur_p->data[i].p     = 0.0f;
+            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            if (len0 < 0) {
+                ctx->buf0.resize(len0);
+                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                assert(len0 > 0);
+            }
+
+            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            if (len1 < 0) {
+                ctx->buf1.resize(len1);
+                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                assert(len1 > 0);
+            }
+
+            // token i0 is a prefix of token i1
+            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+                int dst = i0;
+                int src = i1;
+
+                // merge into the token with higher probability
+                if (cur_p->data[i1].p > cur_p->data[i0].p) {
+                    std::swap(dst, src);
                 }
 
+                cur_p->data[dst].p += cur_p->data[src].p;
+                cur_p->data[src].logit = -INFINITY;
+                cur_p->data[src].p     = 0.0f;
+
                 n_combined++;
             }
         }
@@ -1936,6 +1956,8 @@ struct llama_sampler * llama_sampler_init_infill_impl(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
             /* .vocab = */ &vocab,
+            /* .buf0 = */ std::vector<char>(512),
+            /* .buf1 = */ std::vector<char>(512),
         },
     };
 }
index 57d56a3d300e8969f139e7227a86036614995690..0a49ddbe3e291602fc64f74c197f3e4f657fedea 100644 (file)
@@ -1858,23 +1858,6 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token0,
-                     llama_token   token1) {
-    char text_buf_0[128];
-    char text_buf_1[128];
-
-    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
-    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);
-
-    if (len0 <= 0 || len1 <= 0) {
-        return false;
-    }
-
-    return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
-}
-
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
                const llama_token * tokens,
index 68479c6dba0495c0eef840184c6888e499d28b7a..d8e2b006c17ef2e231b8f2cc31bc887cd1aa4845 100644 (file)
@@ -21466,13 +21466,6 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 
-bool llama_token_is_prefix(
-    const struct llama_model * model,
-                 llama_token   token0,
-                 llama_token   token1) {
-    return llama_token_is_prefix_impl(model->vocab, token0, token1);
-}
-
 int32_t llama_detokenize(
     const struct llama_model * model,
            const llama_token * tokens,