]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
test : add simple grammar parsing tests (#2594)
authordrbh <redacted>
Sun, 13 Aug 2023 14:00:48 +0000 (10:00 -0400)
committerGitHub <redacted>
Sun, 13 Aug 2023 14:00:48 +0000 (17:00 +0300)
* adds simple grammar parsing tests

* adds cassert header

.gitignore
Makefile
tests/CMakeLists.txt
tests/test-grammar-parser.cpp [new file with mode: 0644]

index e345e64ed91e4a7e883e311d54bc29d431204ecc..743b8a8b6e0916196de740b0a71391823efd19e5 100644 (file)
@@ -70,6 +70,7 @@ poetry.lock
 poetry.toml
 
 # Test binaries
+tests/test-grammar-parser
 tests/test-double-float
 tests/test-grad0
 tests/test-opt
index ce593edfc0aa1df8ac641ce80f2c29837c52cd44..070ae124282065e356b81f9145d91c904b21049c 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@
 BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test
 
 # Binaries only useful for tests
-TEST_TARGETS = tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
+TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
 
 default: $(BUILD_TARGETS)
 
@@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
 vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
        $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
+tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+       $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
+
 tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS)
        $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
 
index 1a40edbec58c4c0bb4714a55824a82350430bf70..689fb6f2afe5a1b73d3251b987d8ad024509ca25 100644 (file)
@@ -11,5 +11,6 @@ llama_add_test(test-quantize-fns.cpp)
 llama_add_test(test-quantize-perf.cpp)
 llama_add_test(test-sampling.cpp)
 llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
+llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
 llama_add_test(test-grad0.cpp) # SLOW
 # llama_add_test(test-opt.cpp) # SLOW
diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp
new file mode 100644 (file)
index 0000000..7022988
--- /dev/null
@@ -0,0 +1,249 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "llama.h"
+#include "examples/grammar-parser.cpp"
+#include <cassert>
+
+int main()
+{
+    grammar_parser::parse_state parsed_grammar;
+
+    const char *grammar_bytes = R"""(root  ::= (expr "=" term "\n")+
+expr  ::= term ([-+*/] term)*
+term  ::= [0-9]+)""";
+
+    parsed_grammar = grammar_parser::parse(grammar_bytes);
+
+    std::vector<std::pair<std::string, uint32_t>> expected = {
+        {"expr", 2},
+        {"expr_5", 5},
+        {"expr_6", 6},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_4", 4},
+        {"term", 3},
+        {"term_7", 7},
+    };
+
+    uint32_t index = 0;
+    for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
+    {
+        std::string key = it->first;
+        uint32_t value = it->second;
+        std::pair<std::string, uint32_t> expected_pair = expected[index];
+
+        // pretty print error message before asserting
+        if (expected_pair.first != key || expected_pair.second != value)
+        {
+            fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second);
+            fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value);
+            fprintf(stderr, "expected_pair != actual_pair\n");
+        }
+
+        assert(expected_pair.first == key && expected_pair.second == value);
+
+        index++;
+    }
+    std::vector<llama_grammar_element> expected_rules = {
+        {LLAMA_GRETYPE_RULE_REF, 4},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 2},
+        {LLAMA_GRETYPE_CHAR, 61},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_CHAR, 10},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_RULE_REF, 6},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 7},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 1},
+        {LLAMA_GRETYPE_RULE_REF, 4},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_RULE_REF, 1},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 45},
+        {LLAMA_GRETYPE_CHAR_ALT, 43},
+        {LLAMA_GRETYPE_CHAR_ALT, 42},
+        {LLAMA_GRETYPE_CHAR_ALT, 47},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 5},
+        {LLAMA_GRETYPE_RULE_REF, 6},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 48},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+        {LLAMA_GRETYPE_RULE_REF, 7},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_CHAR, 48},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+        {LLAMA_GRETYPE_END, 0},
+    };
+
+    index = 0;
+    for (auto rule : parsed_grammar.rules)
+    {
+        // compare rule to expected rule
+        for (uint32_t i = 0; i < rule.size(); i++)
+        {
+            llama_grammar_element element = rule[i];
+            llama_grammar_element expected_element = expected_rules[index];
+
+            // pretty print error message before asserting
+            if (expected_element.type != element.type || expected_element.value != element.value)
+            {
+                fprintf(stderr, "index: %d\n", index);
+                fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
+                fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value);
+                fprintf(stderr, "expected_element != actual_element\n");
+            }
+
+            assert(expected_element.type == element.type && expected_element.value == element.value);
+            index++;
+        }
+    }
+
+    const char *longer_grammar_bytes = R"""(
+    root  ::= (expr "=" ws term "\n")+
+    expr  ::= term ([-+*/] term)*
+    term  ::= ident | num | "(" ws expr ")" ws
+    ident ::= [a-z] [a-z0-9_]* ws
+    num   ::= [0-9]+ ws
+    ws    ::= [ \t\n]*
+    )""";
+
+    parsed_grammar = grammar_parser::parse(longer_grammar_bytes);
+
+    expected = {
+        {"expr", 2},
+        {"expr_6", 6},
+        {"expr_7", 7},
+        {"ident", 8},
+        {"ident_10", 10},
+        {"num", 9},
+        {"num_11", 11},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_5", 5},
+        {"term", 4},
+        {"ws", 3},
+        {"ws_12", 12},
+    };
+
+    index = 0;
+    for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it)
+    {
+        std::string key = it->first;
+        uint32_t value = it->second;
+        std::pair<std::string, uint32_t> expected_pair = expected[index];
+
+        // pretty print error message before asserting
+        if (expected_pair.first != key || expected_pair.second != value)
+        {
+            fprintf(stderr, "expected_pair: %s, %d\n", expected_pair.first.c_str(), expected_pair.second);
+            fprintf(stderr, "actual_pair: %s, %d\n", key.c_str(), value);
+            fprintf(stderr, "expected_pair != actual_pair\n");
+        }
+
+        assert(expected_pair.first == key && expected_pair.second == value);
+
+        index++;
+    }
+    expected_rules = {
+        {LLAMA_GRETYPE_RULE_REF, 5},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 2},
+        {LLAMA_GRETYPE_CHAR, 61},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_RULE_REF, 4},
+        {LLAMA_GRETYPE_CHAR, 10},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 4},
+        {LLAMA_GRETYPE_RULE_REF, 7},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 12},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 8},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_RULE_REF, 9},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_CHAR, 40},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_RULE_REF, 2},
+        {LLAMA_GRETYPE_CHAR, 41},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 1},
+        {LLAMA_GRETYPE_RULE_REF, 5},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_RULE_REF, 1},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 45},
+        {LLAMA_GRETYPE_CHAR_ALT, 43},
+        {LLAMA_GRETYPE_CHAR_ALT, 42},
+        {LLAMA_GRETYPE_CHAR_ALT, 47},
+        {LLAMA_GRETYPE_RULE_REF, 4},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 6},
+        {LLAMA_GRETYPE_RULE_REF, 7},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 97},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+        {LLAMA_GRETYPE_RULE_REF, 10},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_RULE_REF, 11},
+        {LLAMA_GRETYPE_RULE_REF, 3},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 97},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+        {LLAMA_GRETYPE_CHAR_ALT, 48},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+        {LLAMA_GRETYPE_CHAR_ALT, 95},
+        {LLAMA_GRETYPE_RULE_REF, 10},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 48},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+        {LLAMA_GRETYPE_RULE_REF, 11},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_CHAR, 48},
+        {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+        {LLAMA_GRETYPE_END, 0},
+        {LLAMA_GRETYPE_CHAR, 32},
+        {LLAMA_GRETYPE_CHAR_ALT, 9},
+        {LLAMA_GRETYPE_CHAR_ALT, 10},
+        {LLAMA_GRETYPE_RULE_REF, 12},
+        {LLAMA_GRETYPE_ALT, 0},
+        {LLAMA_GRETYPE_END, 0},
+    };
+
+    index = 0;
+    for (auto rule : parsed_grammar.rules)
+    {
+        // compare rule to expected rule
+        for (uint32_t i = 0; i < rule.size(); i++)
+        {
+            llama_grammar_element element = rule[i];
+            llama_grammar_element expected_element = expected_rules[index];
+
+            // pretty print error message before asserting
+            if (expected_element.type != element.type || expected_element.value != element.value)
+            {
+                fprintf(stderr, "index: %d\n", index);
+                fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
+                fprintf(stderr, "actual_element: %d, %d\n", element.type, element.value);
+                fprintf(stderr, "expected_element != actual_element\n");
+            }
+
+            assert(expected_element.type == element.type && expected_element.value == element.value);
+            index++;
+        }
+    }
+
+    return 0;
+}