]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : minor grammar refactor (#10897)
authorGeorgi Gerganov <redacted>
Thu, 19 Dec 2024 15:42:13 +0000 (17:42 +0200)
committerGitHub <redacted>
Thu, 19 Dec 2024 15:42:13 +0000 (17:42 +0200)
ggml-ci

examples/gbnf-validator/gbnf-validator.cpp
src/llama-grammar.cpp
src/llama-grammar.h
tests/test-grammar-integration.cpp
tests/test-llama-grammar.cpp

index 7493af9d3aec32aef2c9e3330eb75d65b516f634..17a0e27c444e86dd52956555256cef99838daaa0 100644 (file)
 static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
     const auto cpts = unicode_cpts_from_utf8(input_str);
 
-    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
-          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
 
     size_t pos = 0;
     for (const auto & cpt : cpts) {
-        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
-
-        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
+        llama_grammar_accept(grammar, cpt);
 
         if (stacks_cur.empty()) {
             error_pos = pos;
             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
-            stacks_cur = stacks_prev;
             return false;
         }
         ++pos;
@@ -82,7 +78,8 @@ int main(int argc, char** argv) {
 
     llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
     if (grammar == nullptr) {
-        throw std::runtime_error("Failed to initialize llama_grammar");
+        fprintf(stdout, "Failed to initialize llama_grammar\n");
+        return 1;
     }
     // Read the input file
     std::string input_str;
index 74e9f64b393b2f2e144f78b1e30830771e91099b..76d0cb3a2ff78b5a8e1293ad5e4564bda4c56316 100644 (file)
@@ -822,15 +822,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
     return grammar->stacks;
 }
 
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t               chr,
-              llama_grammar_stacks & stacks_new) {
-    stacks_new.clear();
-    stacks_new.reserve(stacks.size());
+void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
+    llama_grammar_stacks stacks_new;
+    stacks_new.reserve(grammar->stacks.size());
 
-    for (const auto & stack : stacks) {
+    for (const auto & stack : grammar->stacks) {
         if (stack.empty()) {
             continue;
         }
@@ -844,9 +840,11 @@ void llama_grammar_accept(
             if (!llama_grammar_is_end_of_sequence(pos)) {
                 new_stack.push_back(pos);
             }
-            llama_grammar_advance_stack(rules, new_stack, stacks_new);
+            llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
         }
     }
+
+    grammar->stacks = std::move(stacks_new);
 }
 
 llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
@@ -1051,7 +1049,12 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
 }
 
 struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) {
-    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
+    llama_grammar * result = new llama_grammar {
+        grammar.vocab,
+        grammar.rules,
+        grammar.stacks,
+        grammar.partial_utf8,
+    };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
@@ -1059,7 +1062,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
             for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
                 for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
                     if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
-                         result->stacks[is][ie]  =  &result->rules[ir0][ir1];
+                        result->stacks[is][ie] =  &result->rules[ir0][ir1];
                     }
                 }
             }
@@ -1126,11 +1129,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
     const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;
 
-    llama_grammar_stacks stacks_new;
-
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new);
-        grammar.stacks = std::move(stacks_new);
+        llama_grammar_accept(&grammar, *it);
     }
 
     grammar.partial_utf8 = decoded.second;
index f529ce351e4167d03cbeb538018047a6287c1c02..13e940fb52e24f9880293ffcb248f19a9ca5eff9 100644 (file)
@@ -58,6 +58,7 @@ using llama_grammar_rules      = std::vector<llama_grammar_rule>;
 using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
 using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
 
+// TODO: remove, needed for tests atm
 const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
       llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 
@@ -65,11 +66,7 @@ const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar
 // be positioned at a character range (see `llama_grammar_advance_stack`), and
 // produces the N possible stacks if the given char is accepted at those
 // positions
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-                          uint32_t   chr,
-              llama_grammar_stacks & stacks_new);
+void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr);
 
 std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
         const llama_grammar_rules      & rules,
index 5cc0cdb04751ff776dc1a727bfc1edbd1c6edc70..e1bdbb9250fca4ac93d1523c861885aa3a262f70 100644 (file)
@@ -32,13 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
 static bool match_string(const std::string & input, llama_grammar * grammar) {
     const auto cpts = unicode_cpts_from_utf8(input);
 
-    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
-          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
 
     for (const auto & cpt : cpts) {
-        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
-
-        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
+        llama_grammar_accept(grammar, cpt);
 
         if (stacks_cur.empty()) {
             // no stacks means that the grammar failed to match at this point
@@ -63,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
     auto * grammar = build_grammar(grammar_str);
 
     // Save the original grammar stacks so that we can reset after every new string we want to test
-    const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
+    const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
 
     llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 
index 6f1374ca8ed58f1d1ecdeefbe1db42a32ce22fa7..e2129206be15650cf7f156d866cd1d395a8574be 100644 (file)
@@ -113,12 +113,10 @@ int main()
         }
     }
 
-    llama_grammar * grammar = NULL;
     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
 
-    grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-    if (grammar == nullptr)
-    {
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    if (grammar == nullptr) {
         throw std::runtime_error("Failed to initialize llama_grammar");
     }