]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Added support for . (any character) token in grammar engine. (#6467)
authorClint Herron <redacted>
Thu, 6 Jun 2024 13:08:52 +0000 (06:08 -0700)
committerGitHub <redacted>
Thu, 6 Jun 2024 13:08:52 +0000 (06:08 -0700)
* Added support for . (any characer) token in grammar engine.

* Add integration tests for any-character symbol.

common/grammar-parser.cpp
llama.cpp
llama.h
tests/test-grammar-integration.cpp

index 79d2b0354b90c359c117d6820540c1a37842d772..a518b766dc33e5a3bf5921cacbe62184363fffab 100644 (file)
@@ -266,6 +266,10 @@ namespace grammar_parser {
                     throw std::runtime_error(std::string("expecting ')' at ") + pos);
                 }
                 pos = parse_space(pos + 1, is_nested);
+            } else if (*pos == '.') { // any char
+                last_sym_start = out_elements.size();
+                out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
+                pos = parse_space(pos + 1, is_nested);
             } else if (*pos == '*') {
                 pos = parse_space(pos + 1, is_nested);
                 handle_repetitions(0, -1);
@@ -401,6 +405,7 @@ namespace grammar_parser {
             case LLAMA_GRETYPE_CHAR_NOT:       return true;
             case LLAMA_GRETYPE_CHAR_ALT:       return true;
             case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+            case LLAMA_GRETYPE_CHAR_ANY:       return true;
             default:                           return false;
         }
     }
@@ -415,6 +420,7 @@ namespace grammar_parser {
                 case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
                 case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
                 case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
+                case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
             }
             switch (elem.type) {
                 case LLAMA_GRETYPE_END:
@@ -426,6 +432,7 @@ namespace grammar_parser {
                 case LLAMA_GRETYPE_CHAR_NOT:
                 case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                 case LLAMA_GRETYPE_CHAR_ALT:
+                case LLAMA_GRETYPE_CHAR_ANY:
                     fprintf(file, "(\"");
                     print_grammar_char(file, elem.value);
                     fprintf(file, "\") ");
@@ -483,11 +490,15 @@ namespace grammar_parser {
                     }
                     print_grammar_char(file, elem.value);
                     break;
+                case LLAMA_GRETYPE_CHAR_ANY:
+                    fprintf(file, ".");
+                    break;
             }
             if (is_char_element(elem)) {
                 switch (rule[i + 1].type) {
                     case LLAMA_GRETYPE_CHAR_ALT:
                     case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+                    case LLAMA_GRETYPE_CHAR_ANY:
                         break;
                     default:
                         fprintf(file, "] ");
index cefb4d1d52dc38e159c8521595be247d63483edd..32264a0082a907c5b13c0db378737e9706b07406 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -13640,7 +13640,7 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
         const uint32_t                chr) {
 
     bool found            = false;
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
 
     GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
 
@@ -13649,6 +13649,10 @@ static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
             // inclusive range, e.g. [a-z]
             found = found || (pos->value <= chr && chr <= pos[1].value);
             pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            found = true;
+            pos += 1;
         } else {
             // exact char match, e.g. [a] or "a"
             found = found || pos->value == chr;
@@ -13666,7 +13670,7 @@ static bool llama_grammar_match_partial_char(
         const llama_grammar_element * pos,
         const llama_partial_utf8      partial_utf8) {
 
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
     GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
 
     uint32_t partial_value = partial_utf8.value;
@@ -13696,6 +13700,9 @@ static bool llama_grammar_match_partial_char(
                 return is_positive_char;
             }
             pos += 2;
+        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
+            // Any character matches "."
+            return true;
         } else {
             // exact char match, e.g. [a] or "a"
             if (low <= pos->value && pos->value <= high) {
@@ -13756,6 +13763,7 @@ static void llama_grammar_advance_stack(
         }
         case LLAMA_GRETYPE_CHAR:
         case LLAMA_GRETYPE_CHAR_NOT:
+        case LLAMA_GRETYPE_CHAR_ANY:
             if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                 // only add the stack if it's not a duplicate of one we already have
                 new_stacks.emplace_back(stack);
diff --git a/llama.h b/llama.h
index 9dcd67bef5036e427fa3ff21aaa44d39e82e2734..62908261f279164ee58c19d1f07a2b33f33da348 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -365,6 +365,9 @@ extern "C" {
         // modifies a preceding LLAMA_GRETYPE_CHAR or
         // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
         LLAMA_GRETYPE_CHAR_ALT       = 6,
+
+        // any character (.)
+        LLAMA_GRETYPE_CHAR_ANY       = 7,
     };
 
     typedef struct llama_grammar_element {
index 9bdab05af72593abb698c44786de9a5ace664d3a..8787fb1ec6987c6298f57c0ef41ba2e626c70813 100644 (file)
@@ -205,6 +205,33 @@ static void test_complex_grammar() {
     );
 }
 
+static void test_special_chars() {
+    // A collection of tests to exercise special characters such as "."
+    test_grammar(
+        "special characters",
+        // Grammar
+        R"""(
+            root ::= ... "abc" ...
+            )""",
+        // Passing strings
+        {
+            "abcabcabc",
+            "aaaabcccc",
+            // NOTE: Also ensures that multi-byte characters still count as a single character
+            "🔵🟠✅abc❌🟠🔵"
+        },
+        // Failing strings
+        {
+            "aaabcccc",
+            "aaaaabcccc",
+            "aaaabccc",
+            "aaaabccccc",
+            "🔵🟠✅❌abc❌✅🟠🔵"
+            "🔵🟠abc🟠🔵"
+        }
+    );
+}
+
 static void test_quantifiers() {
     // A collection of tests to exercise * + and ? quantifiers
 
@@ -445,6 +472,7 @@ int main() {
     fprintf(stdout, "Running grammar integration tests...\n");
     test_simple_grammar();
     test_complex_grammar();
+    test_special_chars();
     test_quantifiers();
     test_failure_missing_root();
     test_failure_missing_reference();