]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Extending grammar integration tests (#6644)
authorClint Herron <redacted>
Mon, 29 Apr 2024 18:40:14 +0000 (14:40 -0400)
committerGitHub <redacted>
Mon, 29 Apr 2024 18:40:14 +0000 (14:40 -0400)
* Cleaning up integration tests to share code between tests and make it simpler to add new tests.

* Add tests around quantifiers to ensure both matching and non-matching compliance.

* Add slightly more complex grammar with quantifiers to test references with quantifiers.

* Fixing build when C++17 is not present.

* Separating test calls to give more helpful stack traces on failure. Adding verbose messages to give visibility for what is being tested.

* Adding quotes around strings to explicitly show whitespace

* Removing trailing whitespace.

* Implementing suggestions from @ochafik -- grammars and test strings now print and flush before tests to aid in debugging segfaults and whatnot.

* Cleaning up forgotten symbols. Modifying simple test to use test harness. Added comments for more verbose descriptions of what each test is accomplishing.

* Unicode symbol modifications to hopefully make log easier to parse visually.

tests/test-grammar-integration.cpp

index 2d8f228e3769d72d91caf3cddc4154a8ff502ef1..1a4004e2ab1755d9bbfedcd3e6079b03994fa77f 100644 (file)
 #include "unicode.h"
 #include <cassert>
 #include <string>
+#include <vector>
 
-static void test_simple_grammar() {
-    // Test case for a simple grammar
-    const std::string grammar_str = R"""(root ::= expr
-expr ::= term ("+" term)*
-term ::= number
-number ::= [0-9]+)""";
-
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+static llama_grammar* build_grammar(const std::string & grammar_str) {
+    auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
 
     // Ensure we parsed correctly
     assert(!parsed_grammar.rules.empty());
@@ -30,8 +25,10 @@ number ::= [0-9]+)""";
     llama_grammar* grammar = llama_grammar_init(
         grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
 
-    std::string input = "123+456";
+    return grammar;
+}
 
+static bool match_string(const std::string & input, llama_grammar* grammar) {
     auto decoded = decode_utf8(input, {});
 
     const auto & code_points = decoded.first;
@@ -39,159 +36,67 @@ number ::= [0-9]+)""";
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         auto prev_stacks = grammar->stacks;
         llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-        assert(!grammar->stacks.empty());
+        if (grammar->stacks.empty()) {
+            // no stacks means that the grammar failed to match at this point
+            return false;
+        }
     }
 
-    bool completed_grammar = false;
-
     for (const auto & stack : grammar->stacks) {
         if (stack.empty()) {
-            completed_grammar = true;
-            break;
+            // An empty stack means that the grammar has been completed
+            return true;
         }
     }
 
-    assert(completed_grammar);
-
-    // Clean up allocated memory
-    llama_grammar_free(grammar);
+    return false;
 }
 
-static void test_complex_grammar() {
-    // Test case for a more complex grammar, with both failure strings and success strings
-    const std::string grammar_str = R"""(root ::= expression
-expression ::= term ws (("+"|"-") ws term)*
-term ::= factor ws (("*"|"/") ws factor)*
-factor ::= number | variable | "(" expression ")" | function-call
-number ::= [0-9]+
-variable ::= [a-zA-Z_][a-zA-Z0-9_]*
-function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
-ws ::= [ \t\n\r]?)""";
-
-    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
-    // Ensure we parsed correctly
-    assert(!parsed_grammar.rules.empty());
+static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
+    fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
+    fflush(stderr);
 
-    // Ensure we have a root node
-    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
-
-    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
-    llama_grammar* grammar = llama_grammar_init(
-        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    auto grammar = build_grammar(grammar_str);
 
     // Save the original grammar stacks so that we can reset after every new string we want to test
     auto original_stacks = grammar->stacks;
 
-    // Test a few strings
-    std::vector<std::string> test_strings_pass = {
-        "42",
-        "1*2*3*4*5",
-        "x",
-        "x+10",
-        "x1+y2",
-        "(a+b)*(c-d)",
-        "func()",
-        "func(x,y+2)",
-        "a*(b+c)-d/e",
-        "f(g(x),h(y,z))",
-        "x + 10",
-        "x1 + y2",
-        "(a + b) * (c - d)",
-        "func()",
-        "func(x, y + 2)",
-        "a * (b + c) - d / e",
-        "f(g(x), h(y, z))",
-        "123+456",
-        "123*456*789-123/456+789*123",
-        "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
-    };
-
-    std::vector<std::string> test_strings_fail = {
-        "+",
-        "/ 3x",
-        "x + + y",
-        "a * / b",
-        "func(,)",
-        "func(x y)",
-        "(a + b",
-        "x + y)",
-        "a + b * (c - d",
-        "42 +",
-        "x +",
-        "x + 10 +",
-        "(a + b) * (c - d",
-        "func(",
-        "func(x, y + 2",
-        "a * (b + c) - d /",
-        "f(g(x), h(y, z)",
-        "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
-    };
+    fprintf(stderr, "  🔵 Valid strings:\n");
 
     // Passing strings
-    for (const auto & test_string : test_strings_pass) {
-        auto decoded = decode_utf8(test_string, {});
-
-        const auto & code_points = decoded.first;
-
-        int pos = 0;
-        for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-            ++pos;
-            auto prev_stacks = grammar->stacks;
-            llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-
-            // Expect that each code point will not cause the grammar to fail
-            if (grammar->stacks.empty()) {
-                fprintf(stdout, "Error at position %d\n", pos);
-                fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
-                fprintf(stderr, "Input string is %s:\n", test_string.c_str());
-            }
-            assert(!grammar->stacks.empty());
-        }
+    for (const auto & test_string : passing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
 
-        bool completed_grammar = false;
+        bool matched = match_string(test_string, grammar);
 
-        for (const auto & stack : grammar->stacks) {
-            if (stack.empty()) {
-                completed_grammar = true;
-                break;
-            }
+        if (!matched) {
+            fprintf(stderr, "❌ (failed to match)\n");
+        } else {
+            fprintf(stdout, "✅︎\n");
         }
 
-        assert(completed_grammar);
+        assert(matched);
 
         // Reset the grammar stacks
         grammar->stacks = original_stacks;
     }
 
+    fprintf(stderr, "  🟠 Invalid strings:\n");
+
     // Failing strings
-    for (const auto & test_string : test_strings_fail) {
-        auto decoded = decode_utf8(test_string, {});
-
-        const auto & code_points = decoded.first;
-        bool parse_failed = false;
-
-        for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-            auto prev_stacks = grammar->stacks;
-            llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-            if (grammar->stacks.empty()) {
-                parse_failed = true;
-                break;
-            }
-            assert(!grammar->stacks.empty());
-        }
+    for (const auto & test_string : failing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
 
-        bool completed_grammar = false;
+        bool matched = match_string(test_string, grammar);
 
-        for (const auto & stack : grammar->stacks) {
-            if (stack.empty()) {
-                completed_grammar = true;
-                break;
-            }
+        if (matched) {
+            fprintf(stderr, "❌ (incorrectly matched)\n");
+        } else {
+            fprintf(stdout, "✅︎\n");
         }
-
-        // Ensure that the grammar is not completed, or that each string failed to match as-expected
-        assert((!completed_grammar) || parse_failed);
+        assert(!matched);
 
         // Reset the grammar stacks
         grammar->stacks = original_stacks;
@@ -201,7 +106,183 @@ ws ::= [ \t\n\r]?)""";
     llama_grammar_free(grammar);
 }
 
+static void test_simple_grammar() {
+    // Test case for a simple grammar
+    test_grammar(
+        "simple grammar",
+        R"""(
+            root ::= expr
+            expr ::= term ("+" term)*
+            term ::= number
+            number ::= [0-9]+)""",
+        // Passing strings
+        {
+            "42",
+            "1+2+3+4+5",
+            "123+456",
+        },
+        // Failing strings
+        {
+            "+",
+            "/ 3",
+            "1+2+3+4+5+",
+            "12a45",
+        }
+    );
+}
+
+static void test_complex_grammar() {
+    // Test case for a more complex grammar, with both failure strings and success strings
+    test_grammar(
+        "medium complexity grammar",
+        // Grammar
+        R"""(
+            root ::= expression
+            expression ::= term ws (("+"|"-") ws term)*
+            term ::= factor ws (("*"|"/") ws factor)*
+            factor ::= number | variable | "(" expression ")" | function-call
+            number ::= [0-9]+
+            variable ::= [a-zA-Z_][a-zA-Z0-9_]*
+            function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
+            ws ::= [ \t\n\r]?)""",
+        // Passing strings
+        {
+            "42",
+            "1*2*3*4*5",
+            "x",
+            "x+10",
+            "x1+y2",
+            "(a+b)*(c-d)",
+            "func()",
+            "func(x,y+2)",
+            "a*(b+c)-d/e",
+            "f(g(x),h(y,z))",
+            "x + 10",
+            "x1 + y2",
+            "(a + b) * (c - d)",
+            "func()",
+            "func(x, y + 2)",
+            "a * (b + c) - d / e",
+            "f(g(x), h(y, z))",
+            "123+456",
+            "123*456*789-123/456+789*123",
+            "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
+        },
+        // Failing strings
+        {
+            "+",
+            "/ 3x",
+            "x + + y",
+            "a * / b",
+            "func(,)",
+            "func(x y)",
+            "(a + b",
+            "x + y)",
+            "a + b * (c - d",
+            "42 +",
+            "x +",
+            "x + 10 +",
+            "(a + b) * (c - d",
+            "func(",
+            "func(x, y + 2",
+            "a * (b + c) - d /",
+            "f(g(x), h(y, z)",
+            "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
+        }
+    );
+}
+
+static void test_quantifiers() {
+    // A collection of tests to exercise * + and ? quantifiers
+
+    test_grammar(
+        "* quantifier",
+        // Grammar
+        R"""(root ::= "a"*)""",
+        // Passing strings
+        {
+            "",
+            "a",
+            "aaaaa",
+            "aaaaaaaaaaaaaaaaaa",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+        },
+        // Failing strings
+        {
+            "b",
+            "ab",
+            "aab",
+            "ba",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
+        }
+    );
+    test_grammar(
+        "+ quantifier",
+        // Grammar
+        R"""(root ::= "a"+)""",
+        // Passing strings
+        {
+            "a",
+            "aaaaa",
+            "aaaaaaaaaaaaaaaaaa",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+        },
+        // Failing strings
+        {
+            "",
+            "b",
+            "ab",
+            "aab",
+            "ba",
+            "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
+        }
+    );
+    test_grammar(
+        "? quantifier",
+        // Grammar
+        R"""(root ::= "a"?)""",
+        // Passing strings
+        {
+            "",
+            "a"
+        },
+        // Failing strings
+        {
+            "b",
+            "ab",
+            "aa",
+            "ba",
+        }
+    );
+    test_grammar(
+        "mixed quantifiers",
+        // Grammar
+        R"""(
+            root ::= cons+ vowel* cons? (vowel cons)*
+            vowel ::= [aeiouy]
+            cons ::= [bcdfghjklmnpqrstvwxyz]
+            )""",
+        // Passing strings
+        {
+            "yes",
+            "no",
+            "noyes",
+            "crwth",
+            "four",
+            "bryyyy",
+        },
+        // Failing strings
+        {
+            "yess",
+            "yesno",
+            "forty",
+            "catyyy",
+        }
+    );
+}
+
 static void test_failure_missing_root() {
+    fprintf(stderr, "⚫ Testing missing root node:\n");
     // Test case for a grammar that is missing a root rule
     const std::string grammar_str = R"""(rot ::= expr
 expr ::= term ("+" term)*
@@ -215,29 +296,37 @@ number ::= [0-9]+)""";
 
     // Ensure we do NOT have a root node
     assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
+    fprintf(stderr, "  ✅︎ Passed\n");
 }
 
 static void test_failure_missing_reference() {
+    fprintf(stderr, "⚫ Testing missing reference node:\n");
+
     // Test case for a grammar that is missing a referenced rule
-    const std::string grammar_str = R"""(root ::= expr
+    const std::string grammar_str =
+R"""(root ::= expr
 expr ::= term ("+" term)*
 term ::= numero
 number ::= [0-9]+)""";
 
-    fprintf(stderr, "Expected error:  ");
+    fprintf(stderr, "    Expected error:  ");
 
     grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
 
     // Ensure we did NOT parsed correctly
     assert(parsed_grammar.rules.empty());
 
-    fprintf(stderr, "End of expected error. Test successful.\n");
+    fprintf(stderr, "    End of expected error.\n");
+    fprintf(stderr, "  ✅︎ Passed\n");
 }
 
 int main() {
+    fprintf(stdout, "Running grammar integration tests...\n");
     test_simple_grammar();
     test_complex_grammar();
+    test_quantifiers();
     test_failure_missing_root();
     test_failure_missing_reference();
+    fprintf(stdout, "All tests passed.\n");
     return 0;
 }