]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
tests : adds simple llama grammar tests (#2618)
authordrbh <redacted>
Thu, 17 Aug 2023 07:41:01 +0000 (03:41 -0400)
committerGitHub <redacted>
Thu, 17 Aug 2023 07:41:01 +0000 (10:41 +0300)
* adds simple llama grammar tests

* fix lint and add Makefile

* 0 terminate code_points

* avoid dangling pointers in candidate cleanup

* cleanup grammar at end of test

Makefile
tests/CMakeLists.txt
tests/test-llama-grammar.cpp [new file with mode: 0644]

index 5b801d16f5ecde9a222373451052bc66e2c88873..376a091dc3ded9fd5e759bc054ecb854d8599478 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@
 BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test
 
 # Binaries only useful for tests
-TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
+TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
 
 default: $(BUILD_TARGETS)
 
@@ -412,6 +412,9 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o
 vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS)
        $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
 
+tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+       $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
+
 tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS)
        $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS)
 
index 689fb6f2afe5a1b73d3251b987d8ad024509ca25..276f39b3b7ea431ad3e0e356afb89ae15371edfa 100644 (file)
@@ -12,5 +12,6 @@ llama_add_test(test-quantize-perf.cpp)
 llama_add_test(test-sampling.cpp)
 llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
 llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp)
+llama_add_test(test-llama-grammar.cpp  ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/common.cpp)
 llama_add_test(test-grad0.cpp) # SLOW
 # llama_add_test(test-opt.cpp) # SLOW
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
new file mode 100644 (file)
index 0000000..f98c653
--- /dev/null
@@ -0,0 +1,403 @@
+#ifdef NDEBUG
+#undef NDEBUG
+#endif
+
+#include "llama.cpp"
+#include "examples/common.cpp"
+#include "examples/grammar-parser.cpp"
+#include <cassert>
+
+int main()
+{
+    grammar_parser::parse_state parsed_grammar;
+
+    std::vector<std::pair<std::string, uint32_t>> expected = {
+        {"expr", 2},
+        {"expr_6", 6},
+        {"expr_7", 7},
+        {"ident", 8},
+        {"ident_10", 10},
+        {"num", 9},
+        {"num_11", 11},
+        {"root", 0},
+        {"root_1", 1},
+        {"root_5", 5},
+        {"term", 4},
+        {"ws", 3},
+        {"ws_12", 12},
+    };
+
+    std::vector<std::vector<llama_grammar_element>> expected_rules = {
+        {{LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_RULE_REF, 2},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_RULE_REF, 4},
+            {LLAMA_GRETYPE_CHAR, 10},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 4}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_END, 0}},
+        {{LLAMA_GRETYPE_RULE_REF, 12}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_RULE_REF, 8},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_RULE_REF, 9},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_CHAR, 40},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_RULE_REF, 2},
+            {LLAMA_GRETYPE_CHAR, 41},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_RULE_REF, 5}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_RULE_REF, 1}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 45},
+            {LLAMA_GRETYPE_CHAR_ALT, 43},
+            {LLAMA_GRETYPE_CHAR_ALT, 42},
+            {LLAMA_GRETYPE_CHAR_ALT, 47},
+            {LLAMA_GRETYPE_RULE_REF, 4},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 6}, {LLAMA_GRETYPE_RULE_REF, 7}, {LLAMA_GRETYPE_ALT, 0}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 97},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+            {LLAMA_GRETYPE_RULE_REF, 10},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {{LLAMA_GRETYPE_RULE_REF, 11}, {LLAMA_GRETYPE_RULE_REF, 3}, {LLAMA_GRETYPE_END, 0}},
+        {
+            {LLAMA_GRETYPE_CHAR, 97},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 122},
+            {LLAMA_GRETYPE_CHAR_ALT, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_CHAR_ALT, 95},
+            {LLAMA_GRETYPE_RULE_REF, 10},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_RULE_REF, 11},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_CHAR, 48},
+            {LLAMA_GRETYPE_CHAR_RNG_UPPER, 57},
+            {LLAMA_GRETYPE_END, 0},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 32},
+            {LLAMA_GRETYPE_CHAR_ALT, 9},
+            {LLAMA_GRETYPE_CHAR_ALT, 10},
+            {LLAMA_GRETYPE_RULE_REF, 12},
+            {LLAMA_GRETYPE_ALT, 0},
+            {LLAMA_GRETYPE_END, 0},
+        },
+    };
+
+    for (auto pair : expected)
+    {
+        parsed_grammar.symbol_ids[pair.first] = pair.second;
+    }
+
+    for (auto rule : expected_rules)
+    {
+        parsed_grammar.rules.push_back({});
+        for (auto element : rule)
+        {
+            parsed_grammar.rules.back().push_back(element);
+        }
+    }
+
+    llama_grammar *grammar = NULL;
+    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+    grammar = llama_grammar_init(
+        grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+
+    std::vector<std::vector<llama_grammar_element>> expected_stacks = {
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 97},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_RULE_REF, 5},
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 40},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 97},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_RULE_REF, 3},
+            {LLAMA_GRETYPE_CHAR, 48},
+        },
+        {
+            {LLAMA_GRETYPE_CHAR, 61},
+            {LLAMA_GRETYPE_RULE_REF, 7},
+            {LLAMA_GRETYPE_CHAR, 40},
+        }};
+
+    auto index = 0;
+    for (auto stack : grammar->stacks)
+    {
+        // compare stack to expected_stack
+        for (uint32_t i = 0; i < stack.size(); i++)
+        {
+            auto element = stack[i];
+            auto expected_element = expected_stacks[index][i];
+
+            // pretty print error message before asserting
+            if (expected_element.type != element->type || expected_element.value != element->value)
+            {
+                fprintf(stderr, "index: %d\n", index);
+                fprintf(stderr, "expected_element: %d, %d\n", expected_element.type, expected_element.value);
+                fprintf(stderr, "actual_element: %d, %d\n", element->type, element->value);
+                fprintf(stderr, "expected_element != actual_element\n");
+            }
+
+            assert(expected_element.type == element->type && expected_element.value == element->value);
+        }
+        index++;
+    }
+
+    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
+    std::vector<llama_grammar_candidate> next_candidates;
+    next_candidates.resize(24);
+
+    for (size_t i = 0; i < 24; ++i)
+    {
+        uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
+        cp[0] = 37 + i;
+        cp[1] = 0;
+        next_candidates[i] = {i, cp};
+    }
+
+    std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {3, 40},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+        {
+            {0, 37},
+            {1, 38},
+            {2, 39},
+            {4, 41},
+            {5, 42},
+            {6, 43},
+            {7, 44},
+            {8, 45},
+            {9, 46},
+            {10, 47},
+            {11, 48},
+            {12, 49},
+            {13, 50},
+            {14, 51},
+            {15, 52},
+            {16, 53},
+            {17, 54},
+            {18, 55},
+            {19, 56},
+            {20, 57},
+            {21, 58},
+            {22, 59},
+            {23, 60},
+        },
+    };
+
+    std::vector<llama_grammar_candidate> rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[0], next_candidates);
+
+    std::vector<std::vector<llama_grammar_candidate>> all_rejects;
+
+    for (std::size_t count = 0; count < grammar->stacks.size(); ++count)
+    {
+        rejects = llama_grammar_reject_candidates_for_stack(grammar->rules, grammar->stacks[count], next_candidates);
+        all_rejects.push_back(rejects);
+    }
+
+    index = 0;
+    for (auto rej : all_rejects)
+    {
+        for (uint32_t i = 0; i < rej.size(); i++)
+        {
+            auto element = rej[i];
+            auto expected_element = expected_reject[index][i];
+            assert(element.index == expected_element.first && *element.code_points == expected_element.second);
+        }
+        index++;
+    }
+
+    for (auto &candidate : next_candidates)
+    {
+        delete[] candidate.code_points;
+        candidate.code_points = nullptr;
+    }
+    delete grammar;
+    return 0;
+}