]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
upgrade to llguidance 0.7.10 (#12576)
authorMichaƂ Moskal <redacted>
Wed, 26 Mar 2025 18:06:09 +0000 (11:06 -0700)
committerGitHub <redacted>
Wed, 26 Mar 2025 18:06:09 +0000 (11:06 -0700)
common/CMakeLists.txt
common/llguidance.cpp
tests/test-grammar-llguidance.cpp

index 17146fffc11685d6cfde7be474a66004673e4e36..829eb5b7238b9227792dec10bc64716ddf0a07b4 100644 (file)
@@ -114,8 +114,8 @@ if (LLAMA_LLGUIDANCE)
 
     ExternalProject_Add(llguidance_ext
         GIT_REPOSITORY https://github.com/guidance-ai/llguidance
-        # v0.6.12:
-        GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
+        # v0.7.10:
+        GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8
         PREFIX ${CMAKE_BINARY_DIR}/llguidance
         SOURCE_DIR ${LLGUIDANCE_SRC}
         BUILD_IN_SOURCE TRUE
index 2feeb93c87e3018b66afd0e5e503f71d24be8355..8bff89ea4aa9e810819f62ffd93e19e09bf1020c 100644 (file)
@@ -11,25 +11,24 @@ struct llama_sampler_llg {
     std::string         grammar_kind;
     std::string         grammar_data;
     LlgTokenizer *      tokenizer;
-    LlgConstraint *     grammar;
-    LlgMaskResult       llg_res;
-    bool                has_llg_res;
+    LlgMatcher *        grammar;
 };
 
-static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
-                                             const char * grammar_data) {
+static LlgMatcher * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
+                                          const char * grammar_data) {
     LlgConstraintInit cinit;
     llg_constraint_init_set_defaults(&cinit, tokenizer);
     const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
     if (log_level && *log_level) {
         cinit.log_stderr_level = atoi(log_level);
     }
-    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
-    if (llg_get_error(c)) {
-        LOG_ERR("llg error: %s\n", llg_get_error(c));
-        llg_free_constraint(c);
+    auto c = llg_new_matcher(&cinit, grammar_kind, grammar_data);
+    if (llg_matcher_get_error(c)) {
+        LOG_ERR("llg error: %s\n", llg_matcher_get_error(c));
+        llg_free_matcher(c);
         return nullptr;
     }
+
     return c;
 }
 
@@ -40,39 +39,29 @@ static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
 static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
     auto * ctx = (llama_sampler_llg *) smpl->ctx;
     if (ctx->grammar) {
-        LlgCommitResult res;
-        llg_commit_token(ctx->grammar, token, &res);
-        ctx->has_llg_res = false;
+        llg_matcher_consume_token(ctx->grammar, token);
     }
 }
 
 static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_llg *) smpl->ctx;
     if (ctx->grammar) {
-        if (!ctx->has_llg_res) {
-            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
-                ctx->has_llg_res = true;
+        const uint32_t * mask = llg_matcher_get_mask(ctx->grammar);
+        if (mask == nullptr) {
+            if (llg_matcher_compute_mask(ctx->grammar) == 0) {
+                mask = llg_matcher_get_mask(ctx->grammar);
             } else {
-                LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
-                llg_free_constraint(ctx->grammar);
+                LOG_ERR("llg error: %s\n", llg_matcher_get_error(ctx->grammar));
+                llg_free_matcher(ctx->grammar);
                 ctx->grammar = nullptr;
+                return;
             }
         }
-        if (ctx->has_llg_res) {
-            if (ctx->llg_res.is_stop) {
-                for (size_t i = 0; i < cur_p->size; ++i) {
-                    if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
-                        cur_p->data[i].logit = -INFINITY;
-                    }
-                }
-            } else {
-                const uint32_t * mask = ctx->llg_res.sample_mask;
-                for (size_t i = 0; i < cur_p->size; ++i) {
-                    auto token = cur_p->data[i].id;
-                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
-                        cur_p->data[i].logit = -INFINITY;
-                    }
-                }
+
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            auto token = cur_p->data[i].id;
+            if ((mask[token / 32] & (1 << (token % 32))) == 0) {
+                cur_p->data[i].logit = -INFINITY;
             }
         }
     }
@@ -80,14 +69,9 @@ static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array
 
 static void llama_sampler_llg_reset(llama_sampler * smpl) {
     auto * ctx = (llama_sampler_llg *) smpl->ctx;
-    if (!ctx->grammar) {
-        return;
+    if (ctx->grammar) {
+        llg_matcher_reset(ctx->grammar);
     }
-
-    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
-    llg_free_constraint(ctx->grammar);
-    ctx->grammar     = grammar_new;
-    ctx->has_llg_res = false;
 }
 
 static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
@@ -102,7 +86,7 @@ static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
         if (ctx->grammar) {
             result_ctx->grammar_kind = ctx->grammar_kind;
             result_ctx->grammar_data = ctx->grammar_data;
-            result_ctx->grammar      = llg_clone_constraint(ctx->grammar);
+            result_ctx->grammar      = llg_clone_matcher(ctx->grammar);
             result_ctx->tokenizer    = llg_clone_tokenizer(ctx->tokenizer);
         }
     }
@@ -114,7 +98,7 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
     const auto * ctx = (llama_sampler_llg *) smpl->ctx;
 
     if (ctx->grammar) {
-        llg_free_constraint(ctx->grammar);
+        llg_free_matcher(ctx->grammar);
         llg_free_tokenizer(ctx->tokenizer);
     }
 
@@ -239,9 +223,11 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
             /* .grammar_data = */ grammar_data,
             /* .tokenizer    = */ tokenizer,
             /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
-            /* .llg_res      = */ {},
-            /* .has_llg_res  = */ false,
         };
+        if (ctx->grammar) {
+            GGML_ASSERT(((size_t) llama_vocab_n_tokens(vocab) + 31) / 32 * 4 ==
+                        llg_matcher_get_mask_byte_size(ctx->grammar));
+        }
     } else {
         *ctx = {
             /* .vocab        = */ vocab,
@@ -249,15 +235,12 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
             /* .grammar_data = */ {},
             /* .tokenizer    = */ nullptr,
             /* .grammar      = */ nullptr,
-            /* .llg_res      = */ {},
-            /* .has_llg_res  = */ false,
         };
     }
 
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_llg_i,
-        /* .ctx   = */ ctx
-    );
+        /* .ctx   = */ ctx);
 }
 
 #else
index 8b696006be714f4700b61f790146d6a1bd2792a3..3c19220e11964a5303ba8dd09dcb5bfec17dc769 100644 (file)
@@ -1086,6 +1086,65 @@ static void test_json_schema() {
         });
 }
 
+static void one_hot(llama_token_data_array & tok_arr, llama_token selected) {
+    auto n_vocab = tok_arr.size;
+
+    tok_arr.selected = -1;
+    tok_arr.sorted   = false;
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        tok_arr.data[token_id].id    = token_id;
+        tok_arr.data[token_id].logit = 0.0f;
+    }
+
+    tok_arr.data[selected].logit = 100.0f;
+}
+
+static void test_sampler_chain(void) {
+    auto sparams            = llama_sampler_chain_default_params();
+    sparams.no_perf         = false;
+    llama_sampler * sampler = llama_sampler_chain_init(sparams);
+
+    const auto grammar_data = R"(%llguidance {}
+start: /[A-Z ]*/)";
+
+    llama_sampler_chain_add(sampler, llama_sampler_init_llg(vocab, "lark", grammar_data));
+    llama_sampler_chain_add(sampler, llama_sampler_init_dist(42));
+
+    auto input  = "ALL YOUR BASE ARE BELONG TO US";
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        one_hot(tok_arr, token);
+
+        fprintf(stderr, "applying token: %d\n", token);
+        llama_sampler_apply(sampler, &tok_arr);
+
+        auto idx = tok_arr.selected;
+        fprintf(stderr, " -> %d %f\n", cur[idx].id, cur[idx].logit);
+        assert(cur[tok_arr.selected].id == token);
+        llama_sampler_accept(sampler, token);
+    }
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    one_hot(tok_arr, tok_eos);
+
+    llama_sampler_apply(sampler, &tok_arr);
+    assert(cur[tok_arr.selected].id == tok_eos);
+}
+
 int main(int argc, const char ** argv) {
     fprintf(stdout, "Running llguidance integration tests...\n");
 
@@ -1135,6 +1194,9 @@ int main(int argc, const char ** argv) {
     test_special_chars();
     test_quantifiers();
     test_json_schema();
+
+    test_sampler_chain();
+
     fprintf(stdout, "All tests passed.\n");
     return 0;
 }