]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Tests: Added integration tests for GBNF parser (#6472)
authorClint Herron <redacted>
Sat, 6 Apr 2024 14:31:33 +0000 (10:31 -0400)
committerGitHub <redacted>
Sat, 6 Apr 2024 14:31:33 +0000 (10:31 -0400)
* Added integration tests for GBNF parser to validate correctness of parsing, as well as correctness of string matching. Intended for use to pin behavior while working on performance improvements.

* Fixing whitespace errors and cleaning error message alert to be clearer.

* Removing hacky include to llama.cpp from grammar integration test now that needed functions are available via internal API.

* Comment cleanup.

* Reorganizing tests for readability.

* Cleaning up debug message to make a bit more sense.

Makefile
tests/CMakeLists.txt
tests/test-grammar-integration.cpp [new file with mode: 0644]

index bdd5ef335ef94f4171298478d7ce0675bd2d63ad..11b31c5c84182b568138d04ba70ae5068a253dd9 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ TEST_TARGETS = \
        tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0-llama          \
        tests/test-tokenizer-0-falcon tests/test-tokenizer-1-llama tests/test-tokenizer-1-bpe tests/test-rope      \
        tests/test-backend-ops tests/test-model-load-cancel tests/test-autorelease                                 \
-       tests/test-json-schema-to-grammar
+       tests/test-json-schema-to-grammar tests/test-grammar-integration
 
 # Code coverage output files
 COV_TARGETS = *.gcno tests/*.gcno *.gcda tests/*.gcda *.gcov tests/*.gcov lcov-report gcovr-report
@@ -918,6 +918,10 @@ tests/test-grammar-parser: tests/test-grammar-parser.cpp ggml.o llama.o grammar-
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
+tests/test-grammar-integration: tests/test-grammar-integration.cpp ggml.o llama.o grammar-parser.o $(OBJS)
+       $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+       $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 tests/test-double-float: tests/test-double-float.cpp ggml.o $(OBJS)
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
index a43439aed96e72b70ab5f16a1af6f231f4e1e044..b5d7bb59c60dbef59243b87378903eeda3c86db0 100644 (file)
@@ -59,6 +59,7 @@ llama_test(test-tokenizer-1-bpe.cpp    NAME test-tokenizer-1-gpt2             AR
 
 llama_test(test-grammar-parser.cpp)
 llama_test(test-llama-grammar.cpp)
+llama_test(test-grammar-integration.cpp)
 llama_test(test-grad0.cpp)
 # llama_test(test-opt.cpp) # SLOW
 llama_test(test-backend-ops.cpp)
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
new file mode 100644 (file)
index 0000000..0a9c3b6
--- /dev/null
@@ -0,0 +1,243 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#define LLAMA_API_INTERNAL
+
+#include "ggml.h"
+#include "llama.h"
+#include "grammar-parser.h"
+#include "unicode.h"
+#include <cassert>
+#include <string>
+
+static void test_simple_grammar() {
+    // Test case for a simple grammar
+    const std::string grammar_str = R"""(root ::= expr
+expr ::= term ("+" term)*
+term ::= number
+number ::= [0-9]+)""";
+
+    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+
+    // Ensure we parsed correctly
+    assert(!parsed_grammar.rules.empty());
+
+    // Ensure we have a root node
+    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
+
+    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
+    llama_grammar* grammar = llama_grammar_init(
+        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+    std::string input = "123+456";
+
+    auto decoded = decode_utf8(input, {});
+
+    const auto & code_points = decoded.first;
+
+    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+        auto prev_stacks = grammar->stacks;
+        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+        assert(!grammar->stacks.empty());
+    }
+
+    bool completed_grammar = false;
+
+    for (const auto & stack : grammar->stacks) {
+        if (stack.empty()) {
+            completed_grammar = true;
+            break;
+        }
+    }
+
+    assert(completed_grammar);
+
+    // Clean up allocated memory
+    llama_grammar_free(grammar);
+}
+
+static void test_complex_grammar() {
+    // Test case for a more complex grammar, with both failure strings and success strings
+    const std::string grammar_str = R"""(root ::= expression
+expression ::= term ws (("+"|"-") ws term)*
+term ::= factor ws (("*"|"/") ws factor)*
+factor ::= number | variable | "(" expression ")" | function-call
+number ::= [0-9]+
+variable ::= [a-zA-Z_][a-zA-Z0-9_]*
+function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
+ws ::= [ \t\n\r]?)""";
+
+    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+
+    // Ensure we parsed correctly
+    assert(!parsed_grammar.rules.empty());
+
+    // Ensure we have a root node
+    assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
+
+    std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
+    llama_grammar* grammar = llama_grammar_init(
+        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+    // Save the original grammar stacks so that we can reset after every new string we want to test
+    auto original_stacks = grammar->stacks;
+
+    // Test a few strings
+    std::vector<std::string> test_strings_pass = {
+        "42",
+        "1*2*3*4*5",
+        "x",
+        "x+10",
+        "x1+y2",
+        "(a+b)*(c-d)",
+        "func()",
+        "func(x,y+2)",
+        "a*(b+c)-d/e",
+        "f(g(x),h(y,z))",
+        "x + 10",
+        "x1 + y2",
+        "(a + b) * (c - d)",
+        "func()",
+        "func(x, y + 2)",
+        "a * (b + c) - d / e",
+        "f(g(x), h(y, z))",
+        "123+456",
+        "123*456*789-123/456+789*123",
+        "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
+    };
+
+    std::vector<std::string> test_strings_fail = {
+        "+",
+        "/ 3x",
+        "x + + y",
+        "a * / b",
+        "func(,)",
+        "func(x y)",
+        "(a + b",
+        "x + y)",
+        "a + b * (c - d",
+        "42 +",
+        "x +",
+        "x + 10 +",
+        "(a + b) * (c - d",
+        "func(",
+        "func(x, y + 2",
+        "a * (b + c) - d /",
+        "f(g(x), h(y, z)",
+        "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
+    };
+
+    // Passing strings
+    for (const auto & test_string : test_strings_pass) {
+        auto decoded = decode_utf8(test_string, {});
+
+        const auto & code_points = decoded.first;
+
+        int pos = 0;
+        for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+            ++pos;
+            auto prev_stacks = grammar->stacks;
+            grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+
+            // Expect that each code point will not cause the grammar to fail
+            if (grammar->stacks.empty()) {
+                fprintf(stdout, "Error at position %d\n", pos);
+                fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
+                fprintf(stderr, "Input string is %s:\n", test_string.c_str());
+            }
+            assert(!grammar->stacks.empty());
+        }
+
+        bool completed_grammar = false;
+
+        for (const auto & stack : grammar->stacks) {
+            if (stack.empty()) {
+                completed_grammar = true;
+                break;
+            }
+        }
+
+        assert(completed_grammar);
+
+        // Reset the grammar stacks
+        grammar->stacks = original_stacks;
+    }
+
+    // Failing strings
+    for (const auto & test_string : test_strings_fail) {
+        auto decoded = decode_utf8(test_string, {});
+
+        const auto & code_points = decoded.first;
+        bool parse_failed = false;
+
+        for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+            auto prev_stacks = grammar->stacks;
+            grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
+            if (grammar->stacks.empty()) {
+                parse_failed = true;
+                break;
+            }
+            assert(!grammar->stacks.empty());
+        }
+
+        bool completed_grammar = false;
+
+        for (const auto & stack : grammar->stacks) {
+            if (stack.empty()) {
+                completed_grammar = true;
+                break;
+            }
+        }
+
+        // Ensure that the grammar is not completed, or that each string failed to match as-expected
+        assert((!completed_grammar) || parse_failed);
+
+        // Reset the grammar stacks
+        grammar->stacks = original_stacks;
+    }
+
+    // Clean up allocated memory
+    llama_grammar_free(grammar);
+}
+
+static void test_failure_missing_root() {
+    // Test case for a grammar that is missing a root rule
+    const std::string grammar_str = R"""(rot ::= expr
+expr ::= term ("+" term)*
+term ::= number
+number ::= [0-9]+)""";
+
+    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+
+    // Ensure we parsed correctly
+    assert(!parsed_grammar.rules.empty());
+
+    // Ensure we do NOT have a root node
+    assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
+}
+
+static void test_failure_missing_reference() {
+    // Test case for a grammar that is missing a referenced rule
+    const std::string grammar_str = R"""(root ::= expr
+expr ::= term ("+" term)*
+term ::= numero
+number ::= [0-9]+)""";
+
+    fprintf(stderr, "Expected error:  ");
+
+    grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+
+    // Ensure we did NOT parsed correctly
+    assert(parsed_grammar.rules.empty());
+
+    fprintf(stderr, "End of expected error. Test successful.\n");
+}
+
+int main() {
+    test_simple_grammar();
+    test_complex_grammar();
+    test_failure_missing_root();
+    test_failure_missing_reference();
+    return 0;
+}