]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
grammars: 1.5x faster inference w/ complex grammars (vector reserves / reuses) (...
authorOlivier Chafik <redacted>
Thu, 11 Apr 2024 18:47:34 +0000 (19:47 +0100)
committerGitHub <redacted>
Thu, 11 Apr 2024 18:47:34 +0000 (19:47 +0100)
* grammars: reserve rejects & next candidates

* grammars: reuse new_stacks

* grammars: fix missing sig change in llama.h

* grammars: fix test (api changed)

* grammars: update gbnf-validator.cpp

* grammars: simpler syntax (no swap)

examples/gbnf-validator/gbnf-validator.cpp
llama.cpp
llama.h
tests/test-grammar-integration.cpp

index e4c0c1689c7a49e4f62f62fa1c400cbd05f4c8e4..091069ffa699c9cd37351a5184a3049012ccfce1 100644 (file)
@@ -17,7 +17,7 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
     size_t pos = 0;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         auto prev_stacks = grammar->stacks;
-        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
         if (grammar->stacks.empty()) {
             error_pos = pos;
             error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
index b6e2ade9134d95949af095578f9dd13626c35751..ad07059c4533a287b3172e47cbecb022efc4c3a4 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -11912,12 +11912,13 @@ static void llama_grammar_advance_stack(
 // be positioned at a character range (see `llama_grammar_advance_stack`), and
 // produces the N possible stacks if the given char is accepted at those
 // positions
-std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+void llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>>         & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t                                                  chr) {
+        const uint32_t                                                  chr,
+        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks) {
 
-    std::vector<std::vector<const llama_grammar_element *>> new_stacks;
+    new_stacks.clear();
 
     for (const auto & stack : stacks) {
         if (stack.empty()) {
@@ -11936,8 +11937,6 @@ std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
             llama_grammar_advance_stack(rules, new_stack, new_stacks);
         }
     }
-
-    return new_stacks;
 }
 
 static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
@@ -11951,6 +11950,7 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
         const std::vector<llama_grammar_candidate>            & candidates) {
 
     std::vector<llama_grammar_candidate> rejects;
+    rejects.reserve(candidates.size());
 
     if (stack.empty()) {
         for (const auto & tok : candidates) {
@@ -11964,6 +11964,8 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_
     const llama_grammar_element * stack_pos = stack.back();
 
     std::vector<llama_grammar_candidate> next_candidates;
+    next_candidates.reserve(candidates.size());
+
     for (const auto & tok : candidates) {
         if (*tok.code_points == 0) {
             // reached end of full codepoints in token, reject iff it ended in a partial sequence
@@ -12771,8 +12773,10 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     // Note terminating 0 in decoded string
     const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
+    std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
+        grammar->stacks = tmp_new_stacks;
     }
     grammar->partial_utf8 = decoded.second;
     GGML_ASSERT(!grammar->stacks.empty());
diff --git a/llama.h b/llama.h
index b770a275ff02fbf4ce664a4194a1a3b66478ad6c..b5da686f7b7e5af5f88b6a2066064fec7276e91c 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -1097,10 +1097,11 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
     struct llama_context * ctx
 );
 
-std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+void llama_grammar_accept(
         const std::vector<std::vector<llama_grammar_element>>         & rules,
         const std::vector<std::vector<const llama_grammar_element *>> & stacks,
-        const uint32_t                                                  chr);
+        const uint32_t                                                  chr,
+        std::vector<std::vector<const llama_grammar_element *>>       & new_stacks);
 
 std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const std::string & src,
index 0a9c3b6f5f7c3664051b5702adff6fb8cc7038b1..2d8f228e3769d72d91caf3cddc4154a8ff502ef1 100644 (file)
@@ -38,7 +38,7 @@ number ::= [0-9]+)""";
 
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         auto prev_stacks = grammar->stacks;
-        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+        llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
         assert(!grammar->stacks.empty());
     }
 
@@ -138,7 +138,7 @@ ws ::= [ \t\n\r]?)""";
         for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
             ++pos;
             auto prev_stacks = grammar->stacks;
-            grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+            llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
 
             // Expect that each code point will not cause the grammar to fail
             if (grammar->stacks.empty()) {
@@ -173,7 +173,7 @@ ws ::= [ \t\n\r]?)""";
 
         for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
             auto prev_stacks = grammar->stacks;
-            grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+            llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
             if (grammar->stacks.empty()) {
                 parse_failed = true;
                 break;