#include "unicode.h"
#include <cassert>
#include <string>
+#include <vector>
-static void test_simple_grammar() {
- // Test case for a simple grammar
- const std::string grammar_str = R"""(root ::= expr
-expr ::= term ("+" term)*
-term ::= number
-number ::= [0-9]+)""";
-
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
+static llama_grammar* build_grammar(const std::string & grammar_str) {
+ auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we parsed correctly
assert(!parsed_grammar.rules.empty());
llama_grammar* grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- std::string input = "123+456";
+ return grammar;
+}
+static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});
const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
- assert(!grammar->stacks.empty());
+ if (grammar->stacks.empty()) {
+ // no stacks means that the grammar failed to match at this point
+ return false;
+ }
}
- bool completed_grammar = false;
-
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
- completed_grammar = true;
- break;
+ // An empty stack means that the grammar has been completed
+ return true;
}
}
- assert(completed_grammar);
-
- // Clean up allocated memory
- llama_grammar_free(grammar);
+ return false;
}
-static void test_complex_grammar() {
- // Test case for a more complex grammar, with both failure strings and success strings
- const std::string grammar_str = R"""(root ::= expression
-expression ::= term ws (("+"|"-") ws term)*
-term ::= factor ws (("*"|"/") ws factor)*
-factor ::= number | variable | "(" expression ")" | function-call
-number ::= [0-9]+
-variable ::= [a-zA-Z_][a-zA-Z0-9_]*
-function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
-ws ::= [ \t\n\r]?)""";
-
- grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
-
- // Ensure we parsed correctly
- assert(!parsed_grammar.rules.empty());
+static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
+ fprintf(stderr, "⚫ Testing %s. Grammar: %s\n", test_desc.c_str(), grammar_str.c_str());
+ fflush(stderr);
- // Ensure we have a root node
- assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()));
-
- std::vector<const llama_grammar_element*> grammar_rules(parsed_grammar.c_rules());
- llama_grammar* grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+ auto grammar = build_grammar(grammar_str);
// Save the original grammar stacks so that we can reset after every new string we want to test
auto original_stacks = grammar->stacks;
- // Test a few strings
- std::vector<std::string> test_strings_pass = {
- "42",
- "1*2*3*4*5",
- "x",
- "x+10",
- "x1+y2",
- "(a+b)*(c-d)",
- "func()",
- "func(x,y+2)",
- "a*(b+c)-d/e",
- "f(g(x),h(y,z))",
- "x + 10",
- "x1 + y2",
- "(a + b) * (c - d)",
- "func()",
- "func(x, y + 2)",
- "a * (b + c) - d / e",
- "f(g(x), h(y, z))",
- "123+456",
- "123*456*789-123/456+789*123",
- "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
- };
-
- std::vector<std::string> test_strings_fail = {
- "+",
- "/ 3x",
- "x + + y",
- "a * / b",
- "func(,)",
- "func(x y)",
- "(a + b",
- "x + y)",
- "a + b * (c - d",
- "42 +",
- "x +",
- "x + 10 +",
- "(a + b) * (c - d",
- "func(",
- "func(x, y + 2",
- "a * (b + c) - d /",
- "f(g(x), h(y, z)",
- "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
- };
+ fprintf(stderr, " 🔵 Valid strings:\n");
// Passing strings
- for (const auto & test_string : test_strings_pass) {
- auto decoded = decode_utf8(test_string, {});
-
- const auto & code_points = decoded.first;
-
- int pos = 0;
- for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- ++pos;
- auto prev_stacks = grammar->stacks;
- llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
-
- // Expect that each code point will not cause the grammar to fail
- if (grammar->stacks.empty()) {
- fprintf(stdout, "Error at position %d\n", pos);
- fprintf(stderr, "Unexpected character '%s'\n", unicode_cpt_to_utf8(*it).c_str());
- fprintf(stderr, "Input string is %s:\n", test_string.c_str());
- }
- assert(!grammar->stacks.empty());
- }
+ for (const auto & test_string : passing_strings) {
+ fprintf(stderr, " \"%s\" ", test_string.c_str());
+ fflush(stderr);
- bool completed_grammar = false;
+ bool matched = match_string(test_string, grammar);
- for (const auto & stack : grammar->stacks) {
- if (stack.empty()) {
- completed_grammar = true;
- break;
- }
+ if (!matched) {
+ fprintf(stderr, "❌ (failed to match)\n");
+ } else {
+ fprintf(stdout, "✅︎\n");
}
- assert(completed_grammar);
+ assert(matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
}
+ fprintf(stderr, " 🟠 Invalid strings:\n");
+
// Failing strings
- for (const auto & test_string : test_strings_fail) {
- auto decoded = decode_utf8(test_string, {});
-
- const auto & code_points = decoded.first;
- bool parse_failed = false;
-
- for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
- auto prev_stacks = grammar->stacks;
- llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
- if (grammar->stacks.empty()) {
- parse_failed = true;
- break;
- }
- assert(!grammar->stacks.empty());
- }
+ for (const auto & test_string : failing_strings) {
+ fprintf(stderr, " \"%s\" ", test_string.c_str());
+ fflush(stderr);
- bool completed_grammar = false;
+ bool matched = match_string(test_string, grammar);
- for (const auto & stack : grammar->stacks) {
- if (stack.empty()) {
- completed_grammar = true;
- break;
- }
+ if (matched) {
+ fprintf(stderr, "❌ (incorrectly matched)\n");
+ } else {
+ fprintf(stdout, "✅︎\n");
}
-
- // Ensure that the grammar is not completed, or that each string failed to match as-expected
- assert((!completed_grammar) || parse_failed);
+ assert(!matched);
// Reset the grammar stacks
grammar->stacks = original_stacks;
llama_grammar_free(grammar);
}
+static void test_simple_grammar() {
+ // Test case for a simple grammar
+ test_grammar(
+ "simple grammar",
+ R"""(
+ root ::= expr
+ expr ::= term ("+" term)*
+ term ::= number
+ number ::= [0-9]+)""",
+ // Passing strings
+ {
+ "42",
+ "1+2+3+4+5",
+ "123+456",
+ },
+ // Failing strings
+ {
+ "+",
+ "/ 3",
+ "1+2+3+4+5+",
+ "12a45",
+ }
+ );
+}
+
+static void test_complex_grammar() {
+ // Test case for a more complex grammar, with both failure strings and success strings
+ test_grammar(
+ "medium complexity grammar",
+ // Grammar
+ R"""(
+ root ::= expression
+ expression ::= term ws (("+"|"-") ws term)*
+ term ::= factor ws (("*"|"/") ws factor)*
+ factor ::= number | variable | "(" expression ")" | function-call
+ number ::= [0-9]+
+ variable ::= [a-zA-Z_][a-zA-Z0-9_]*
+ function-call ::= variable ws "(" (expression ("," ws expression)*)? ")"
+ ws ::= [ \t\n\r]?)""",
+ // Passing strings
+ {
+ "42",
+ "1*2*3*4*5",
+ "x",
+ "x+10",
+ "x1+y2",
+ "(a+b)*(c-d)",
+ "func()",
+ "func(x,y+2)",
+ "a*(b+c)-d/e",
+ "f(g(x),h(y,z))",
+ "x + 10",
+ "x1 + y2",
+ "(a + b) * (c - d)",
+ "func()",
+ "func(x, y + 2)",
+ "a * (b + c) - d / e",
+ "f(g(x), h(y, z))",
+ "123+456",
+ "123*456*789-123/456+789*123",
+ "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456"
+ },
+ // Failing strings
+ {
+ "+",
+ "/ 3x",
+ "x + + y",
+ "a * / b",
+ "func(,)",
+ "func(x y)",
+ "(a + b",
+ "x + y)",
+ "a + b * (c - d",
+ "42 +",
+ "x +",
+ "x + 10 +",
+ "(a + b) * (c - d",
+ "func(",
+ "func(x, y + 2",
+ "a * (b + c) - d /",
+ "f(g(x), h(y, z)",
+ "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
+ }
+ );
+}
+
+static void test_quantifiers() {
+ // A collection of tests to exercise * + and ? quantifiers
+
+ test_grammar(
+ "* quantifier",
+ // Grammar
+ R"""(root ::= "a"*)""",
+ // Passing strings
+ {
+ "",
+ "a",
+ "aaaaa",
+ "aaaaaaaaaaaaaaaaaa",
+ "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ },
+ // Failing strings
+ {
+ "b",
+ "ab",
+ "aab",
+ "ba",
+ "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
+ }
+ );
+ test_grammar(
+ "+ quantifier",
+ // Grammar
+ R"""(root ::= "a"+)""",
+ // Passing strings
+ {
+ "a",
+ "aaaaa",
+ "aaaaaaaaaaaaaaaaaa",
+ "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
+ },
+ // Failing strings
+ {
+ "",
+ "b",
+ "ab",
+ "aab",
+ "ba",
+ "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab"
+ }
+ );
+ test_grammar(
+ "? quantifier",
+ // Grammar
+ R"""(root ::= "a"?)""",
+ // Passing strings
+ {
+ "",
+ "a"
+ },
+ // Failing strings
+ {
+ "b",
+ "ab",
+ "aa",
+ "ba",
+ }
+ );
+ test_grammar(
+ "mixed quantifiers",
+ // Grammar
+ R"""(
+ root ::= cons+ vowel* cons? (vowel cons)*
+ vowel ::= [aeiouy]
+ cons ::= [bcdfghjklmnpqrstvwxyz]
+ )""",
+ // Passing strings
+ {
+ "yes",
+ "no",
+ "noyes",
+ "crwth",
+ "four",
+ "bryyyy",
+ },
+ // Failing strings
+ {
+ "yess",
+ "yesno",
+ "forty",
+ "catyyy",
+ }
+ );
+}
+
static void test_failure_missing_root() {
+ fprintf(stderr, "⚫ Testing missing root node:\n");
// Test case for a grammar that is missing a root rule
const std::string grammar_str = R"""(rot ::= expr
expr ::= term ("+" term)*
// Ensure we do NOT have a root node
assert(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end());
+ fprintf(stderr, " ✅︎ Passed\n");
}
static void test_failure_missing_reference() {
+ fprintf(stderr, "⚫ Testing missing reference node:\n");
+
// Test case for a grammar that is missing a referenced rule
- const std::string grammar_str = R"""(root ::= expr
+ const std::string grammar_str =
+R"""(root ::= expr
expr ::= term ("+" term)*
term ::= numero
number ::= [0-9]+)""";
- fprintf(stderr, "Expected error: ");
+ fprintf(stderr, " Expected error: ");
grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str());
// Ensure we did NOT parsed correctly
assert(parsed_grammar.rules.empty());
- fprintf(stderr, "End of expected error. Test successful.\n");
+ fprintf(stderr, " End of expected error.\n");
+ fprintf(stderr, " ✅︎ Passed\n");
}
int main() {
+ fprintf(stdout, "Running grammar integration tests...\n");
test_simple_grammar();
test_complex_grammar();
+ test_quantifiers();
test_failure_missing_root();
test_failure_missing_reference();
+ fprintf(stdout, "All tests passed.\n");
return 0;
}