]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sampling : support for llguidance grammars (#10224)
authorMichał Moskal <redacted>
Sun, 2 Feb 2025 07:55:32 +0000 (23:55 -0800)
committerGitHub <redacted>
Sun, 2 Feb 2025 07:55:32 +0000 (09:55 +0200)
* initial porting of previous LLG patch

* update for new APIs

* build: integrate llguidance as an external project

* use '%llguidance' as marker to enable llg lark syntax

* add some docs

* clarify docs

* code style fixes

* remove llguidance.h from .gitignore

* fix tests when llg is enabled

* pass vocab not model to llama_sampler_init_llg()

* copy test-grammar-integration.cpp to test-llguidance.cpp

* clang fmt

* fix ref-count bug

* build and run test

* gbnf -> lark syntax

* conditionally include llguidance test based on LLAMA_LLGUIDANCE flag

* rename llguidance test file to test-grammar-llguidance.cpp

* add gh action for llg test

* align tests with LLG grammar syntax and JSON Schema spec

* llama_tokenizer() in fact requires valid utf8

* update llg

* format file

* add $LLGUIDANCE_LOG_LEVEL support

* fix whitespace

* fix warning

* include <cmath> for INFINITY

* add final newline

* fail llama_sampler_init_llg() at runtime

* Link gbnf_to_lark.py script; fix links; refer to llg docs for lexemes

* simplify #includes

* improve doc string for LLAMA_LLGUIDANCE

* typo in merge

* bump llguidance to 0.6.12

13 files changed:
.github/workflows/build.yml
CMakeLists.txt
common/CMakeLists.txt
common/json-schema-to-grammar.cpp
common/json-schema-to-grammar.h
common/llguidance.cpp [new file with mode: 0644]
common/sampling.cpp
common/sampling.h
docs/llguidance.md [new file with mode: 0644]
tests/CMakeLists.txt
tests/test-grammar-integration.cpp
tests/test-grammar-llguidance.cpp [new file with mode: 0644]
tests/test-json-schema-to-grammar.cpp

index 7392f2bfe6eaf496ae0bad30ef6fcfc35ebd44f5..8f9c82f87403e6108764463d102be866cff13b1a 100644 (file)
@@ -302,6 +302,36 @@ jobs:
           cd build
           ctest -L main --verbose --timeout 900
 
+  ubuntu-latest-llguidance:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Dependencies
+        id: depends
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential
+
+      - name: Build
+        id: cmake_build
+        run: |
+          mkdir build
+          cd build
+          cmake .. \
+            -DLLAMA_FATAL_WARNINGS=ON \
+            -DLLAMA_LLGUIDANCE=ON
+          cmake --build . --config Release -j $(nproc)
+
+      - name: Test
+        id: cmake_test
+        run: |
+          cd build
+          ctest -L main --verbose --timeout 900
+
   ubuntu-latest-cmake-rpc:
     runs-on: ubuntu-latest
 
index 4c62d17880ab221b07694fd6fd5ec3e502af9baf..74b48d24d06bf93a4b09fa9b612bf7e662c0f5ca 100644 (file)
@@ -80,6 +80,7 @@ option(LLAMA_BUILD_SERVER   "llama: build server example" ${LLAMA_STANDALONE})
 
 # 3rd party libs
 option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
+option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)
 
 # Required for relocatable CMake package
 include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
index 72f0915c12524a046ad88f02b174c3529860e52c..e61015d2ad7f9515ca8469f2fc1e48a847d0eea8 100644 (file)
@@ -65,6 +65,7 @@ add_library(${TARGET} STATIC
     console.h
     json-schema-to-grammar.cpp
     json.hpp
+    llguidance.cpp
     log.cpp
     log.h
     minja.hpp
@@ -91,6 +92,33 @@ if (LLAMA_CURL)
     set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
 endif ()
 
+if (LLAMA_LLGUIDANCE)
+    include(ExternalProject)
+    set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
+    set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
+    ExternalProject_Add(llguidance_ext
+        GIT_REPOSITORY https://github.com/guidance-ai/llguidance
+        # v0.6.12:
+        GIT_TAG ced1c9023d47ec194fa977932d35ce65c2ebfc09
+        PREFIX ${CMAKE_BINARY_DIR}/llguidance
+        SOURCE_DIR ${LLGUIDANCE_SRC}
+        BUILD_IN_SOURCE TRUE
+        CONFIGURE_COMMAND ""
+        BUILD_COMMAND cargo build --release
+        INSTALL_COMMAND ""
+        BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
+        UPDATE_COMMAND ""
+    )
+    target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)
+
+    add_library(llguidance STATIC IMPORTED)
+    set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
+    add_dependencies(llguidance llguidance_ext)
+
+    target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
+endif ()
+
 target_include_directories(${TARGET} PUBLIC .)
 target_compile_features   (${TARGET} PUBLIC cxx_std_17)
 target_link_libraries     (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
index 1f47e313edecc404eb32db407d9319ad557f8239..3ebcc3d9fbbf8fe718c66a4f8a348b54bbae8f1a 100644 (file)
@@ -991,7 +991,14 @@ public:
     }
 };
 
-std::string json_schema_to_grammar(const json & schema) {
+std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
+#ifdef LLAMA_USE_LLGUIDANCE
+    if (!force_gbnf) {
+        return "%llguidance {}\nstart: %json " + schema.dump();
+    }
+#else
+    (void)force_gbnf;
+#endif // LLAMA_USE_LLGUIDANCE
     return build_grammar([&](const common_grammar_builder & callbacks) {
         auto copy = schema;
         callbacks.resolve_refs(copy);
index ba4112cb9b02dfdea78fd2e851604fd821e04655..62a3b0a4477cca0f9558183cec9832ed2c421973 100644 (file)
@@ -5,7 +5,8 @@
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
-std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
+std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
+                                   bool force_gbnf = false);
 
 struct common_grammar_builder {
     std::function<std::string(const std::string &, const std::string &)> add_rule;
diff --git a/common/llguidance.cpp b/common/llguidance.cpp
new file mode 100644 (file)
index 0000000..7aa8ddd
--- /dev/null
@@ -0,0 +1,270 @@
+#include "sampling.h"
+#include "log.h"
+
+#ifdef LLAMA_USE_LLGUIDANCE
+
+#    include "llguidance.h"
+#    include <cmath>
+
+// Per-sampler state for the LLGuidance-backed grammar sampler.
+struct llama_sampler_llg {
+    const llama_vocab * vocab;         // vocabulary used for tokenization (not owned)
+    std::string         grammar_kind;  // grammar flavor passed to llg_new_constraint_any (e.g. "lark"); empty when unconstrained
+    std::string         grammar_data;  // grammar source text, kept so reset() can rebuild the constraint
+    LlgTokenizer *      tokenizer;     // LLGuidance tokenizer handle
+    LlgConstraint *     grammar;       // active constraint; nullptr when sampling is unconstrained or after an LLG error
+    LlgMaskResult       llg_res;       // cached token mask from llg_compute_mask()
+    bool                has_llg_res;   // true while llg_res is valid for the current position
+};
+
+// Create a new LLGuidance constraint for the given grammar kind/data.
+// Honors the $LLGUIDANCE_LOG_LEVEL environment variable for stderr log verbosity.
+// Returns nullptr (after logging the LLG error) if the grammar fails to compile.
+static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
+                                             const char * grammar_data) {
+    LlgConstraintInit cinit;
+    llg_constraint_init_set_defaults(&cinit, tokenizer);
+    const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
+    if (log_level && *log_level) {
+        cinit.log_stderr_level = atoi(log_level);
+    }
+    auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
+    if (llg_get_error(c)) {
+        LOG_ERR("llg error: %s\n", llg_get_error(c));
+        llg_free_constraint(c);
+        return nullptr;
+    }
+    return c;
+}
+
+static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
+    return "llguidance";
+}
+
+// Advance the constraint past the sampled token and invalidate the cached
+// mask so the next apply() recomputes it for the new position.
+static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_llg *) smpl->ctx;
+    if (ctx->grammar) {
+        LlgCommitResult res;
+        llg_commit_token(ctx->grammar, token, &res);
+        ctx->has_llg_res = false;
+    }
+}
+
+// Apply the grammar constraint to the candidate logits: tokens the current
+// LLG mask forbids get logit = -INFINITY. The mask is computed lazily (once
+// per position); on an LLG error the grammar is freed and disabled so
+// sampling continues unconstrained.
+static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_llg *) smpl->ctx;
+    if (ctx->grammar) {
+        if (!ctx->has_llg_res) {
+            if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
+                ctx->has_llg_res = true;
+            } else {
+                LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
+                llg_free_constraint(ctx->grammar);
+                ctx->grammar = nullptr;
+            }
+        }
+        if (ctx->has_llg_res) {
+            if (ctx->llg_res.is_stop) {
+                // the grammar says generation must stop: only end-of-generation
+                // tokens remain viable
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
+                        cur_p->data[i].logit = -INFINITY;
+                    }
+                }
+            } else {
+                // sample_mask is a bitmask with one bit per vocab token,
+                // packed 32 tokens per uint32_t
+                const uint32_t * mask = ctx->llg_res.sample_mask;
+                for (size_t i = 0; i < cur_p->size; ++i) {
+                    auto token = cur_p->data[i].id;
+                    if ((mask[token / 32] & (1 << (token % 32))) == 0) {
+                        cur_p->data[i].logit = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
+}
+
+// Reset to the initial parse state by rebuilding the constraint from the
+// stored grammar kind/data; no-op when the sampler is unconstrained.
+static void llama_sampler_llg_reset(llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_llg *) smpl->ctx;
+    if (!ctx->grammar) {
+        return;
+    }
+
+    auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
+    llg_free_constraint(ctx->grammar);
+    ctx->grammar     = grammar_new;
+    ctx->has_llg_res = false;
+}
+
+// Deep-clone the sampler: both the LLG constraint and the tokenizer are
+// cloned so the copy can advance its parse state independently of the
+// original. An unconstrained sampler clones to another unconstrained one.
+static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_llg *) smpl->ctx;
+
+    // start from an empty (unconstrained) sampler over the same vocab
+    auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_llg *) result->ctx;
+
+        if (ctx->grammar) {
+            result_ctx->grammar_kind = ctx->grammar_kind;
+            result_ctx->grammar_data = ctx->grammar_data;
+            result_ctx->grammar      = llg_clone_constraint(ctx->grammar);
+            result_ctx->tokenizer    = llg_clone_tokenizer(ctx->tokenizer);
+        }
+    }
+
+    return result;
+}
+
+// Free the sampler context together with its LLG constraint and tokenizer.
+// NOTE(review): the tokenizer is only freed when grammar is non-null; if
+// constraint creation failed in llama_sampler_init_llg() (or apply() nulled
+// the grammar after an LLG error) a non-null tokenizer would leak here —
+// confirm whether that path can occur.
+static void llama_sampler_llg_free(llama_sampler * smpl) {
+    const auto * ctx = (llama_sampler_llg *) smpl->ctx;
+
+    if (ctx->grammar) {
+        llg_free_constraint(ctx->grammar);
+        llg_free_tokenizer(ctx->tokenizer);
+    }
+
+    delete ctx;
+}
+
+static llama_sampler_i llama_sampler_llg_i = {
+    /* .name   = */ llama_sampler_llg_name,
+    /* .accept = */ llama_sampler_llg_accept_impl,
+    /* .apply  = */ llama_sampler_llg_apply,
+    /* .reset  = */ llama_sampler_llg_reset,
+    /* .clone  = */ llama_sampler_llg_clone,
+    /* .free   = */ llama_sampler_llg_free,
+};
+
+// Tokenization callback handed to LLGuidance (user_data is the llama_vocab).
+// Returns the number of tokens written to output_tokens; when the output
+// buffer is too small, llama_tokenize returns the negated required size,
+// which is returned to LLG as a positive count.
+static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
+                                            uint32_t * output_tokens, size_t output_tokens_len) {
+    const llama_vocab * vocab = (const llama_vocab *) user_data;
+    int                 r     = 0;
+    try {
+        r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
+                           true);
+    } catch (const std::exception & e) {
+        GGML_ABORT("llama_tokenize failed: %s\n", e.what());
+    }
+    if (r < 0) {
+        return -r;
+    }
+    return r;
+}
+
+// Build (or fetch from a one-entry cache) an LLGuidance tokenizer for the
+// given vocab. The full vocabulary is detokenized into a flat byte buffer
+// with per-token lengths, then handed to llg_new_tokenizer(). Always returns
+// a clone so the cached instance stays owned by this function.
+// NOTE(review): the static cache is not synchronized — presumably callers
+// construct samplers from a single thread; verify before concurrent use.
+static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
+    // TODO store the tokenizer in the vocab somehow
+    static const llama_vocab * vocab_cache;
+    static LlgTokenizer *      tokenizer_cache;
+
+    if (vocab_cache == vocab) {
+        return llg_clone_tokenizer(tokenizer_cache);
+    }
+
+    // prefer end-of-turn over end-of-sequence as the EOS token for LLG
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    size_t vocab_size = llama_vocab_n_tokens(vocab);
+
+    auto token_lens       = new uint32_t[vocab_size];
+    // we typically have ~7 bytes per token; let's go on the safe side here
+    auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
+    auto token_bytes      = new uint8_t[token_bytes_size];
+
+    size_t offset = 0;
+    for (size_t i = 0; i < vocab_size; i++) {
+        size_t max_token = 1024;
+        if (token_bytes_size - offset < max_token) {
+            GGML_ABORT("token_bytes buffer too small\n");
+        }
+
+        llama_token token = i;
+        auto        dp    = (char *) token_bytes + offset;
+        auto        size  = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
+        if (size < 0) {
+            GGML_ABORT("llama_detokenize failed\n");
+        }
+        if (size == 0) {
+            // empty without specials: retry rendering special tokens, and mark
+            // the result with a 0xff prefix byte so LLG can tell them apart
+            size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
+            if (size < 0) {
+                GGML_ABORT("llama_detokenize failed\n");
+            }
+            if (size != 0) {
+                *dp = '\xff';  // special token prefix marker
+                size += 1;
+            }
+        }
+
+        token_lens[i] = size;
+        offset += size;
+    }
+
+    LlgTokenizerInit tinit = {
+        /* .vocab_size                         = */ (uint32_t) vocab_size,
+        /* .tok_eos                            = */ (uint32_t) tok_eos,
+        /* .token_lens                         = */ token_lens,
+        /* .token_bytes                        = */ token_bytes,
+        /* .tokenizer_json                     = */ nullptr,
+        /* .tokenize_assumes_string            = */ true,
+        /* .tokenize_fn                        = */ llama_sampler_llg_tokenize_fn,
+        /* .use_approximate_greedy_tokenize_fn = */ false,
+        /* .tokenize_user_data                 = */ vocab,
+    };
+
+    char           error_buffer[1024];
+    LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));
+
+    // llg_new_tokenizer copies what it needs; the scratch buffers can go
+    delete[] token_bytes;
+    delete[] token_lens;
+
+    if (tokenizer == nullptr) {
+        LOG_ERR("llg tokenizer error: %s\n", error_buffer);
+        return tokenizer;
+    }
+
+    // replace the cached entry with the freshly built tokenizer
+    if (tokenizer_cache) {
+        llg_free_tokenizer(tokenizer_cache);
+    }
+    vocab_cache     = vocab;
+    tokenizer_cache = tokenizer;
+
+    return llg_clone_tokenizer(tokenizer_cache);
+}
+
+// Create an LLGuidance-backed grammar sampler over the given vocab.
+// With a null/empty grammar_kind the sampler is created unconstrained
+// (tokenizer and grammar are null, so accept/apply are no-ops).
+llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
+                                       const char * grammar_data) {
+    auto * ctx = new llama_sampler_llg;
+
+    if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
+        auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
+        *ctx           = {
+            /* .vocab        = */ vocab,
+            /* .grammar_kind = */ grammar_kind,
+            /* .grammar_data = */ grammar_data,
+            /* .tokenizer    = */ tokenizer,
+            /* .grammar      = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
+            /* .llg_res      = */ {},
+            /* .has_llg_res  = */ false,
+        };
+    } else {
+        *ctx = {
+            /* .vocab        = */ vocab,
+            /* .grammar_kind = */ {},
+            /* .grammar_data = */ {},
+            /* .tokenizer    = */ nullptr,
+            /* .grammar      = */ nullptr,
+            /* .llg_res      = */ {},
+            /* .has_llg_res  = */ false,
+        };
+    }
+
+    return new llama_sampler{
+        /* .iface = */ &llama_sampler_llg_i,
+        /* .ctx   = */ ctx,
+    };
+}
+
+#else
+
+// Stub used when llama.cpp is built without -DLLAMA_LLGUIDANCE=ON:
+// warns and returns nullptr so the failure surfaces at runtime.
+llama_sampler * llama_sampler_init_llg(const llama_vocab *, const char *, const char *) {
+    LOG_WRN("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
+    return nullptr;
+}
+
+#endif  // LLAMA_USE_LLGUIDANCE
index bc7e49fdb27223cdfb5629338771f0cc036b02bd..e4b21ca1011dddbd1da2a6f8ba6013beaee43605 100644 (file)
@@ -156,13 +156,25 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     for (const auto & str : params.grammar_trigger_words) {
         trigger_words.push_back(str.word.c_str());
     }
+
+    struct llama_sampler * grmr;
+    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
+#ifdef LLAMA_USE_LLGUIDANCE
+        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
+#else
+        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
+#endif // LLAMA_USE_LLGUIDANCE
+    } else {
+        grmr = params.grammar_lazy
+             ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
+                                               trigger_words.data(), trigger_words.size(),
+                                               params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
+             :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
+    }
+
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr   = */ params.grammar_lazy
-            ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
-                                              trigger_words.data(), trigger_words.size(),
-                                              params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
-            :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
+        /* .grmr   = */ grmr,
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
index 348911b18888b263e0dd90a9f26e56a2308e1846..2064421db4e80237e0dc8cc3acbcef4894000416 100644 (file)
@@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
 
 std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
 std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
+
+llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
+                const char * grammar_kind, const char * grammar_data);
diff --git a/docs/llguidance.md b/docs/llguidance.md
new file mode 100644 (file)
index 0000000..792d207
--- /dev/null
@@ -0,0 +1,51 @@
+# LLGuidance Support in llama.cpp
+
+[LLGuidance](https://github.com/guidance-ai/llguidance) is a library for constrained decoding (also called constrained sampling or structured outputs) for Large Language Models (LLMs). Initially developed as the backend for the [Guidance](https://github.com/guidance-ai/guidance) library, it can also be used independently.
+
+LLGuidance supports JSON Schemas and arbitrary context-free grammars (CFGs) written in a [variant](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md) of Lark syntax. It is [very fast](https://github.com/guidance-ai/jsonschemabench/tree/main/maskbench) and has [excellent](https://github.com/guidance-ai/llguidance/blob/main/docs/json_schema.md) JSON Schema coverage but requires the Rust compiler, which complicates the llama.cpp build process.
+
+## Building
+
+To enable LLGuidance support, build llama.cpp with the `LLAMA_LLGUIDANCE` option:
+
+```sh
+cmake -B build -DLLAMA_LLGUIDANCE=ON
+make -C build -j
+```
+
+This requires the Rust compiler and the `cargo` tool to be [installed](https://www.rust-lang.org/tools/install).
+
+## Interface
+
+There are no new command-line arguments or modifications to `common_params`. When enabled, grammars starting with `%llguidance` are passed to LLGuidance instead of the [current](../grammars/README.md) llama.cpp grammars. Additionally, JSON Schema requests (e.g., using the `-j` argument in `llama-cli`) are also passed to LLGuidance.
+
+For your existing GBNF grammars, you can use [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) to convert them to LLGuidance Lark-like format.
+
+## Performance
+
+Computing a "token mask" (i.e., the set of allowed tokens) for a llama3 tokenizer with 128k tokens takes, on average, 50”s of single-core CPU time for the [JSON Schema Bench](https://github.com/guidance-ai/jsonschemabench). The p99 time is 0.5ms, and the p100 time is 20ms. These results are due to the lexer/parser split and several [optimizations](https://github.com/guidance-ai/llguidance/blob/main/docs/optimizations.md).
+
+## JSON Schema
+
+LLGuidance adheres closely to the JSON Schema specification. For example:
+
+- `additionalProperties` defaults to `true`, unlike current grammars, though you can set `"additionalProperties": false` if needed.
+- Any whitespace is allowed.
+- The definition order in the `"properties": {}` object is maintained, regardless of whether properties are required (current grammars always put required properties first).
+
+Unsupported schemas result in an error message—no keywords are silently ignored.
+
+## Why Not Reuse GBNF Format?
+
+GBNF lacks the concept of a lexer.
+
+Most programming languages, including JSON, use a two-step process: a lexer (built with regular expressions) converts a byte stream into lexemes, which are then processed by a CFG parser. This approach is faster because lexers are cheaper to evaluate, and there are ~10x fewer lexemes than bytes.
+LLM tokens often align with lexemes, so the parser is engaged in under 0.5% of tokens, with the lexer handling the rest.
+
+However, the user has to provide the distinction between lexemes and CFG symbols. In [Lark](https://github.com/lark-parser/lark), lexeme names are uppercase, while CFG symbols are lowercase.
+The [gbnf_to_lark.py script](https://github.com/guidance-ai/llguidance/blob/main/scripts/gbnf_to_lark.py) can often take care of this automatically.
+See [LLGuidance syntax docs](https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#terminals-vs-rules) for more details.
+
+## Error Handling
+
+Errors are currently printed to `stderr`, and generation continues. Improved error handling may be added in the future.
index 40f83ff0d513d3fc89c3ceebb8f2ddf59db3036e..7a158d6024d78dcd2df61a6027c125cfd17e8a15 100644 (file)
@@ -86,6 +86,9 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2             ARGS ${CMAKE
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact            ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
 
+if (LLAMA_LLGUIDANCE)
+    llama_target_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
+endif ()
 
 if (NOT WIN32)
     # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API
index 288e08f51856c14c4726db96f8070f12f489e2c7..89060864894a4a4e648fb463eaf975fdfd96fa7a 100644 (file)
@@ -129,7 +129,7 @@ static void test_grammar(const std::string & test_desc, const std::string & gram
     test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
 }
 static void test_schema(const std::string & test_desc, const std::string & schema_str, const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
-    test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str)), passing_strings, failing_strings);
+    test(test_desc + ". Schema: " + schema_str, json_schema_to_grammar(json::parse(schema_str), true), passing_strings, failing_strings);
 }
 
 static void test_simple_grammar() {
diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp
new file mode 100644 (file)
index 0000000..8b69600
--- /dev/null
@@ -0,0 +1,1140 @@
+#ifdef NDEBUG
+#    undef NDEBUG
+#endif
+
+#include "unicode.h"
+#include "sampling.h"
+
+#include <cassert>
+#include <string>
+#include <vector>
+
+static const llama_vocab * vocab;
+
+// Check whether `input` is accepted by `grammar`: tokenize it, then for each
+// token apply the sampler mask over uniform (0.0) logits and verify the token
+// was not masked out (logit still >= 0). The string matches only if an
+// EOS/EOT token is also still allowed once all tokens are consumed.
+static bool match_string(const std::string & input, llama_sampler * grammar) {
+    llama_sampler_reset(grammar);
+    auto tokens = common_tokenize(vocab, input, false, false);
+
+    auto n_vocab = llama_vocab_n_tokens(vocab);
+
+    // candidate array covering the whole vocabulary, all logits start at 0
+    std::vector<llama_token_data> cur;
+    cur.reserve(n_vocab);
+    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+        cur.emplace_back(llama_token_data{ token_id, 0.0f, 0.0f });
+    }
+    auto tok_arr = llama_token_data_array{ cur.data(), cur.size(), -1, false };
+
+    for (const auto token : tokens) {
+        // reset logits; apply() sets masked-out tokens to -INFINITY
+        for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
+            cur[token_id].logit = 0.0f;
+        }
+        llama_sampler_apply(grammar, &tok_arr);
+        if (cur[token].logit < 0.0f) {
+            return false;
+        }
+        llama_sampler_accept(grammar, token);
+    }
+
+    // do we allow EOS at the end? if so the grammar is accepting
+
+    auto tok_eos = llama_vocab_eot(vocab);
+    if (tok_eos == LLAMA_TOKEN_NULL) {
+        tok_eos = llama_vocab_eos(vocab);
+    }
+
+    cur[tok_eos].logit = 0.0f;
+    llama_sampler_apply(grammar, &tok_arr);
+
+    return cur[tok_eos].logit >= 0.0f;
+}
+
+// Run one grammar test case: every string in passing_strings must match the
+// LLG lark grammar and every string in failing_strings must not. A failed
+// expectation dumps the grammar and string to debug files and then assert()s.
+static void test(const std::string & test_desc, const std::string & grammar_str,
+                 const std::vector<std::string> & passing_strings, const std::vector<std::string> & failing_strings) {
+    fprintf(stderr, "⚫ Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str());
+    fflush(stderr);
+
+    auto * grammar = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
+
+    fprintf(stderr, "  đŸ”” Valid strings:\n");
+
+    // Passing strings
+    for (const auto & test_string : passing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
+
+        bool matched = match_string(test_string, grammar);
+
+        if (!matched) {
+            fprintf(stderr, "❌ (failed to match)\n");
+
+            // DEBUG: Write strings to files so that we can analyze more easily with gbnf-validator program to see exactly where things failed.
+            // DEBUG: Write the grammar_str to test-grammar-integration.grammar.gbnf
+            FILE * grammar_file = fopen("test-grammar-integration.grammar.gbnf", "w");
+            if (grammar_file) {
+                fprintf(grammar_file, "%s", grammar_str.c_str());
+                fclose(grammar_file);
+            }
+
+            // DEBUG: Write the test string to test-grammar-integration.string.txt
+            FILE * string_file = fopen("test-grammar-integration.string.txt", "w");
+            if (string_file) {
+                fprintf(string_file, "%s", test_string.c_str());
+                fclose(string_file);
+            }
+
+            fprintf(stderr,
+                    "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
+                    "command:     ./llama-gbnf-validator test-grammar-integration.grammar.gbnf "
+                    "test-grammar-integration.string.txt\n\n");
+        } else {
+            fprintf(stdout, "✅\n");
+        }
+
+        assert(matched);
+    }
+
+    fprintf(stderr, "  đŸŸ  Invalid strings:\n");
+
+    // Failing strings
+    for (const auto & test_string : failing_strings) {
+        fprintf(stderr, "    \"%s\" ", test_string.c_str());
+        fflush(stderr);
+
+        bool matched = match_string(test_string, grammar);
+
+        if (matched) {
+            fprintf(stderr, "❌ (incorrectly matched)\n");
+        } else {
+            fprintf(stdout, "✅\n");
+        }
+        assert(!matched);
+    }
+
+    llama_sampler_free(grammar);
+}
+
+// Convenience wrapper: test a raw lark grammar string.
+static void test_grammar(const std::string & test_desc, const std::string & grammar_str,
+                         const std::vector<std::string> & passing_strings,
+                         const std::vector<std::string> & failing_strings) {
+    test(test_desc + ". Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings);
+}
+
+// Convenience wrapper: test a JSON Schema by wrapping it in the
+// "%llguidance ... start: %json ..." lark preamble.
+static void test_schema(const std::string & test_desc, const std::string & schema_str,
+                        const std::vector<std::string> & passing_strings,
+                        const std::vector<std::string> & failing_strings) {
+    test(test_desc + ". Schema: " + schema_str, "%llguidance {}\nstart: %json " + schema_str, passing_strings,
+         failing_strings);
+}
+
+static void test_simple_grammar() {
+    test_schema("min 0",
+                R"""({
+            "type": "integer",
+            "minimum": 0
+        })""",
+                // Passing strings
+                {
+                    "0",
+                    "10",
+                    "12",
+                    "10000",
+                },
+                // Failing strings
+                {
+                    "-1",
+                    "-10",
+                    "-10000",
+                    "-100000000000000000000000000000000",
+                    // "100000000000000000000000000000000",
+                    "00",
+                    "01",
+                    "-0",
+                });
+    test_schema("min 2",
+                // Schema
+                R"""({
+            "type": "integer",
+            "minimum": 2
+        })""",
+                // Passing strings
+                {
+                    "2",
+                    "3",
+                    "4",
+                    "10",
+                    "20",
+                    "1234567890000000",
+                },
+                // Failing strings
+                {
+                    "0", "1", "-1", "-100", "0", "1", "01", "02",
+                    // "12345678900000000",
+                });
+    test_schema("min 456",
+                R"""({
+            "type": "integer",
+            "minimum": 456
+        })""",
+                // Passing strings
+                {
+                    "456",
+                    "4560",
+                    "457",
+                    "460",
+                    "500",
+                },
+                // Failing strings
+                {
+                    "455",
+                    "356",
+                    "50",
+                    "050",
+                    "-1",
+                    "-456",
+                });
+    test_schema("min -123",
+                R"""({
+            "type": "integer",
+            "minimum": -123
+        })""",
+                // Passing strings
+                {
+                    "-123",
+                    "-122",
+                    "-11",
+                    "-1",
+                    "0",
+                    "1",
+                    "123",
+                    "1234",
+                    "2345",
+                },
+                // Failing strings
+                {
+                    "-1234",
+                    "-124",
+                });
+
+    test_schema("max 9999",
+                // Schema
+                R"""({
+            "type": "integer",
+            "maximum": 9999
+        })""",
+                // Passing strings
+                {
+                    "-99999",
+                    "0",
+                    "9999",
+                },
+                // Failing strings
+                {
+                    "10000",
+                    "99991",
+                });
+    test_schema("max -9999",
+                // Schema
+                R"""({
+            "type": "integer",
+            "maximum": -9999
+        })""",
+                // Passing strings
+                {
+                    "-10000",
+                    "-9999",
+                },
+                // Failing strings
+                {
+                    "-9998",
+                    "0",
+                    "9999",
+                });
+    test_schema("min 5 max 30",
+                // Schema
+                R"""({
+            "type": "integer",
+            "minimum": 5,
+            "maximum": 30
+        })""",
+                // Passing strings
+                {
+                    "5",
+                    "10",
+                    "30",
+                },
+                // Failing strings
+                {
+                    "05",
+                    "4",
+                    "-1",
+                    "31",
+                    "123",
+                    "0123",
+                });
+    test_schema("min -1 max 1",
+                R"""({
+            "type": "integer",
+            "minimum": -1,
+            "maximum": 1
+        })""",
+                // Passing strings
+                {
+                    "-1",
+                    "0",
+                    "1",
+                },
+                // Failing strings
+                {
+                    "-11",
+                    "-10",
+                    "-2",
+                    "2",
+                    "10",
+                    "11",
+                });
+    test_schema("min -123 max 42",
+                R"""({
+            "type": "integer",
+            "minimum": -123,
+            "maximum": 42
+        })""",
+                // Passing strings
+                {
+                    "-123",
+                    "-122",
+                    "-13",
+                    "-11",
+                    "-2",
+                    "-1",
+                    "0",
+                    "1",
+                    "5",
+                    "10",
+                    "39",
+                    "40",
+                    "42",
+                },
+                // Failing strings
+                {
+                    "-0123",
+                    "-124",
+                    "-1123",
+                    "-200",
+                    "43",
+                    "123",
+                    "0123",
+                });
+    test_schema("exclusive min / max",
+                // Schema
+                R"""({
+            "type": "integer",
+            "exclusiveMinimum": 0,
+            "exclusiveMaximum": 10000
+        })""",
+                // Passing strings
+                {
+                    "1",
+                    "9999",
+                },
+                // Failing strings
+                {
+                    "0",
+                    "01",
+                    "10000",
+                    "99999",
+                });
+
+    // Test case for a simple grammar
+    test_grammar("simple grammar",
+                 R"""(
+            start: expr
+            expr: term ("+" term)*
+            term: number
+            number: /[0-9]+/ )""",
+                 // Passing strings
+                 {
+                     "42",
+                     "1+2+3+4+5",
+                     "123+456",
+                 },
+                 // Failing strings
+                 {
+                     "+",
+                     "/ 3",
+                     "1+2+3+4+5+",
+                     "12a45",
+                 });
+}
+
+static void test_complex_grammar() {
+    // Test case for a more complex grammar, with both failure strings and success strings
+    test_grammar("medium complexity grammar",
+                 // Grammar
+                 // (arithmetic over numbers, variables and function calls; "ws" permits at most one whitespace char)
+                 R"""(
+            start: expression
+            expression: term ws (("+"|"-") ws term)*
+            term: factor ws (("*"|"/") ws factor)*
+            factor: number | variable | "(" expression ")" | function-call
+            number: /[0-9]+/
+            variable: /[a-zA-Z_][a-zA-Z0-9_]*/
+            function-call: variable ws "(" (expression ("," ws expression)*)? ")"
+            ws: /[ \t\n\r]?/ )""",
+                 // Passing strings
+                 { "42",
+                   "1*2*3*4*5",
+                   "x",
+                   "x+10",
+                   "x1+y2",
+                   "(a+b)*(c-d)",
+                   "func()",
+                   "func(x,y+2)",
+                   "a*(b+c)-d/e",
+                   "f(g(x),h(y,z))",
+                   "x + 10",
+                   "x1 + y2",
+                   "(a + b) * (c - d)",
+                   "func()",
+                   "func(x, y + 2)",
+                   "a * (b + c) - d / e",
+                   "f(g(x), h(y, z))",
+                   "123+456",
+                   "123*456*789-123/456+789*123",
+                   "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456" },
+                 // Failing strings (malformed operators, unbalanced parens, trailing operators)
+                 {
+                     "+",
+                     "/ 3x",
+                     "x + + y",
+                     "a * / b",
+                     "func(,)",
+                     "func(x y)",
+                     "(a + b",
+                     "x + y)",
+                     "a + b * (c - d",
+                     "42 +",
+                     "x +",
+                     "x + 10 +",
+                     "(a + b) * (c - d",
+                     "func(",
+                     "func(x, y + 2",
+                     "a * (b + c) - d /",
+                     "f(g(x), h(y, z)",
+                     "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
+                 });
+}
+
+static void test_special_chars() {
+    // A collection of tests to exercise special characters such as "."
+    // Grammar: exactly three characters, the literal "abc", then exactly three more characters.
+    test_grammar("special characters",
+                 // Grammar
+                 R"""(
+            start: /.../ "abc" /.../
+            )""",
+                 // Passing strings
+                 { "abcabcabc", "aaaabcccc",
+                   // NOTE: Also ensures that multi-byte characters still count as a single character
+                   "đŸ””đŸŸ âœ…abcâŒđŸŸ đŸ””" },
+                 // Failing strings (wrong character counts on either side of "abc")
+                 { "aaabcccc", "aaaaabcccc", "aaaabccc", "aaaabccccc", "đŸ””đŸŸ âœ…âŒabcâŒâœ…đŸŸ đŸ””", "đŸ””đŸŸ abcđŸŸ đŸ””" });
+}
+
+static void test_quantifiers() {
+    // A collection of tests to exercise * + and ? quantifiers
+
+    test_grammar(
+        "* quantifier",
+        // Grammar
+        R"""(start: "a"*)""",
+        // Passing strings
+        { "", "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
+        // Failing strings
+        { "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" },
+    test_grammar(
+        "+ quantifier",
+        // Grammar
+        R"""(start: "a"+)""",
+        // Passing strings
+        { "a", "aaaaa", "aaaaaaaaaaaaaaaaaa", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" },
+        // Failing strings ("+" requires at least one "a", so "" must fail)
+        { "", "b", "ab", "aab", "ba", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaab" });
+    test_grammar("? quantifier",
+                 // Grammar
+                 R"""(start: "a"?)""",
+                 // Passing strings
+                 { "", "a" },
+                 // Failing strings
+                 {
+                     "b",
+                     "ab",
+                     "aa",
+                     "ba",
+                 });
+    test_grammar("mixed quantifiers",
+                 // Grammar
+                 R"""(
+            start: cons+ vowel* cons? (vowel cons)*
+            vowel: /[aeiouy]/
+            cons: /[bcdfghjklmnpqrstvwxyz]/
+            )""",
+                 // Passing strings
+                 {
+                     "yes",
+                     "no",
+                     "noyes",
+                     "crwth",
+                     "four",
+                     "bryyyy",
+                 },
+                 // Failing strings
+                 {
+                     "yess",
+                     "yesno",
+                     "forty",
+                     "catyyy",
+                 });
+    test_grammar("simple exact repetition",
+                 // Grammar
+                 R"""(
+            start: /[ab]{4}/
+        )""",
+                 // Passing strings
+                 {
+                     "aaaa",
+                     "bbbb",
+                     "abab",
+                 },
+                 // Failing strings
+                 {
+                     "a",
+                     "b",
+                     "aaaaa",
+                 });
+    test_grammar("simple min repetition",
+                 // Grammar
+                 R"""(
+            start: /[ab]{4,}/
+        )""",
+                 // Passing strings
+                 {
+                     "aaaa",
+                     "aaaaab",
+                     "bbbb",
+                     "ababab",
+                 },
+                 // Failing strings
+                 {
+                     "",
+                     "aba",
+                 });
+    test_grammar("simple max repetition",
+                 // Grammar
+                 R"""(
+            start: /[ab]{0,4}/
+        )""",
+                 // Passing strings
+                 {
+                     "",
+                     "a",
+                     "aa",
+                     "aaa",
+                     "aaab",
+                 },
+                 // Failing strings
+                 {
+                     "aaaaa",
+                 });
+    // NOTE(review): disabled test -- {m,n} repetition of a grouped sequence is presumably
+    // not supported by the grammar engine here; confirm before re-enabling.
+    // test_grammar("min / max repetition",
+    //              // Grammar
+    //              R"""(
+    //         start: ("0x" /[A-F0-9]{2}/ " "?){3,5}
+    //     )""",
+    //              // Passing strings
+    //              {
+    //                  "0xFF 0x12 0xAB",
+    //                  "0xFF 0x12 0xAB 0x00 0x00",
+    //              },
+    //              // Failing strings
+    //              {
+    //                  "",
+    //                  "0xFF",
+    //                  "0xFF 0x12",
+    //                  "0xFF 0x12 0xAB 0x00 0x00 0x00",
+    //              });
+}
+
+static void test_json_schema() {
+    // Note that this is similar to the regular grammar tests,
+    //  but we convert each json schema to a grammar before parsing.
+    // Otherwise, this test structure is the same.
+
+    test_schema("empty schema (object)",
+                // Schema
+                R"""(
+            {"type":"object"}
+        )""",
+                // Passing strings
+                {
+                    R"""({})""",
+                    R"""({"foo": "bar"})""",
+                },
+                // Failing strings
+                {
+                    "",
+                    "[]",
+                    "null",
+                    R"""("")""",
+                    "true",
+                });
+
+    test_schema(
+        "exotic formats (list)",
+        // Schema
+        R"""({
+            "items": [
+                { "format": "date" },
+                { "format": "uuid" },
+                { "format": "time" },
+                { "format": "date-time" }
+            ]
+        })""",
+        // Passing strings
+        {
+            // "{}", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            // "[]", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            R"""(["2012-04-23", "12345678-1234-1234-1234-1234567890ab", "18:25:43.511Z", "2012-04-23T18:25:43.511Z"])""",
+            //R"""(["2012-04-23","12345678-1234-1234-1234-1234567890ab"])""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+            //R"""({"foo": "bar"})""", // NOTE: This string passes for this schema on https://www.jsonschemavalidator.net/ -- should it?
+        },
+        // Failing strings
+        {
+            R"""(["foo", "bar"])""",
+            R"""(["12345678-1234-1234-1234-1234567890ab"])""",
+        });
+
+    test_schema("string",
+                // Schema
+                R"""({
+            "type": "string"
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("")""",
+                },
+                // Failing strings
+                {
+                    R"""({})""",
+                    R"""("foo": "bar")""",
+                });
+
+    test_schema("string w/ min length 1",
+                // Schema
+                R"""({
+            "type": "string",
+            "minLength": 1
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""({})""",
+                    R"""("foo": "bar")""",
+                });
+
+    test_schema("string w/ min length 3",
+                // Schema
+                R"""({
+                "type": "string",
+                "minLength": 3
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("foobar")""",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""("f")""",
+                    R"""("fo")""",
+                });
+
+    test_schema("string w/ max length",
+                // Schema
+                R"""({
+            "type": "string",
+            "maxLength": 3
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("")""",
+                    R"""("f")""",
+                    R"""("fo")""",
+                },
+                // Failing strings
+                {
+                    R"""("foobar")""",
+                });
+
+    test_schema("string w/ min & max length",
+                // Schema
+                R"""({
+            "type": "string",
+            "minLength": 1,
+            "maxLength": 4
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                    R"""("bar")""",
+                    R"""("f")""",
+                    R"""("barf")""",
+                },
+                // Failing strings
+                {
+                    R"""("")""",
+                    R"""("barfo")""",
+                    R"""("foobar")""",
+                });
+
+    test_schema("boolean",
+                // Schema
+                R"""({
+            "type": "boolean"
+        })""",
+                // Passing strings
+                {
+                    "true",
+                    "false",
+                },
+                // Failing strings (JSON booleans are lowercase and unquoted)
+                {
+                    R"""("")""",
+                    R"""("true")""",
+                    R"""(True)""",
+                    R"""(FALSE)""",
+                });
+
+    test_schema("integer",
+                // Schema
+                R"""({
+            "type": "integer"
+        })""",
+                // Passing strings
+                {
+                    R"""(0)""",
+                    R"""(12345)""",
+                    R"""(1234567890123456)""",
+                },
+                // Failing strings (leading zeros and trailing whitespace are rejected)
+                {
+                    R"""()""",
+                    R"""(01)""",
+                    R"""(007)""",
+                    R"""(12345678901234567  )""",
+                });
+
+    test_schema("string const",
+                // Schema
+                R"""({
+            "const": "foo"
+        })""",
+                // Passing strings
+                {
+                    R"""("foo")""",
+                },
+                // Failing strings
+                {
+                    R"""(foo)""",
+                    R"""("bar")""",
+                });
+
+    test_schema("non-string const",
+                // Schema
+                R"""({
+            "const": true
+        })""",
+                // Passing strings
+                {
+                    R"""(true)""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""(foo)""",
+                    R"""("true")""",
+                });
+
+    // NOTE(review): test name duplicates the previous case but this one exercises
+    // "enum", not "const" -- consider renaming the label.
+    test_schema("non-string const",
+                // Schema
+                R"""({
+            "enum": ["red", "amber", "green", null, 42, ["foo"]]
+        })""",
+                // Passing strings
+                {
+                    R"""("red")""",
+                    R"""(null)""",
+                    R"""(42)""",
+                    R"""(["foo"])""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""(420)""",
+                    R"""(true)""",
+                    R"""(foo)""",
+                });
+
+    test_schema("simple pattern",
+                // Schema
+                R"""({
+            "pattern": "^[a-zA-Z0-9_-]*$"
+        })""",
+                // Passing strings
+                {
+                    R"""("")""",
+                    R"""("He_llo-12")""",
+                },
+                // Failing strings
+                {
+                    R"""("!")""",
+                    R"""("Hello World")""",
+                });
+
+    test_schema("pattern with escapes",
+                // Schema
+                R"""({
+            "pattern": "^a\\^\\$\\.\\[\\]\\(\\)\\|\\{\\}\\*\\+\\?b$"
+        })""",
+                // Passing strings
+                {
+                    R"""("a^$.[]()|{}*+?b")""",
+                },
+                // Failing strings
+                {
+                    R"""("ab")""",
+                });
+
+    // NOTE(review): empty test name -- this case exercises a union type ["array", "null"].
+    test_schema("",
+                // Schema
+                R"""(
+            {
+                "type": ["array", "null"],
+                "items": { "type": "string" }
+            }
+        )""",
+                // Passing strings
+                {
+                    "null",
+                    "[]",
+                    "[\"123\"]",
+                    "[\"foo\", \"bar\"]",
+                },
+                // Failing strings
+                {
+                    "",
+                    "[123]",
+                    "\"foo\"",
+                    "[\"foo\", 42]",
+                });
+
+    test_schema("min+max items",
+                // Schema
+                R"""({
+            "items": {
+                "type": ["number", "integer"]
+            },
+            "minItems": 3,
+            "maxItems": 5
+        })""",
+                // Passing strings
+                {
+                    R"""([1, 2, 3])""",
+                    R"""([1, 2, 3, 4])""",
+                    R"""([1, 2, 3, 4, 5])""",
+                    // this is in fact correct; keywords do not apply if the type is wrong
+                    R"""(1)""",
+                },
+                // Failing strings
+                {
+                    R"""([1, 2])""",
+                    R"""([1, 2, 3, 4, 5, 6])""",
+                });
+
+    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
+    test_schema("object properties",
+                // Schema
+                R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": false
+        })""",
+                // Passing strings
+                {
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // "By default, leaving out properties is valid"
+                    R"""({ "street_name": "Pennsylvania" })""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+                    // "By extension, even an empty object is valid"
+                    R"""({})""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+                },
+                // Failing strings
+                {
+                    // Change datatype from number to string
+                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // Reorder properties
+                    R"""({ "street_name": "Pennsylvania", "number": 1600 })""",
+                    // Reorder properties
+                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // Additional properties set to false
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+
+                });
+
+    test_schema("additional properties can't override other properties",
+                R"""({
+            "properties": {
+                "a": {"type": "integer"},
+                "b": {"type": "integer"}
+            },
+            "additionalProperties": true
+        })""",
+                // Passing strings
+                {
+                    R"""({"a": 42})""",
+                    R"""({"c": ""})""",
+                    R"""({"a": 42, "c": ""})""",
+                    R"""({"a_": ""})""",
+                },
+                // Failing strings
+                {
+                    R"""()""",
+                    R"""({"a": ""})""",
+                    R"""({"a": "", "b": ""})""",
+                });
+
+    // Properties (from: https://json-schema.org/understanding-json-schema/reference/object#properties)
+    test_schema("object properties, additionalProperties: true",
+                // Schema
+                R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": true
+        })""",
+                // Passing strings
+                {
+                    // "By extension, even an empty object is valid"
+                    R"""({})""",
+                    R"""({"number":1600,"street_name":"Pennsylvania","street_type":"Avenue"})""",
+                    // "By default, leaving out properties is valid"
+                    R"""({ "street_name": "Pennsylvania" })""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+                    // "By default, providing additional properties is valid"
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue", "direction":"NW"})""",
+                    R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+                },
+                // Failing strings
+                {
+                    // Change datatype from number to string
+                    R"""({ "number": "1600", "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+                    // Reorder properties
+                    R"""({ "street_name": "Pennsylvania", "number": 1600, "street_type":"Avenue"})""",
+                });
+
+    // Additional properties: false
+    test_schema(
+        "required + optional props each in original order",
+        // Schema
+        R"""({
+            "type": "object",
+            "properties": {
+                "number": { "type": "number" },
+                "street_name": { "type": "string" },
+                "street_type": { "enum": ["Street", "Avenue", "Boulevard"] }
+            },
+            "additionalProperties": false
+        })""",
+        // Passing strings
+        {
+            R"""({ "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_type":"Avenue"})""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania" })""",
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type":"Avenue"})""",
+            // Spaces are permitted around enum values
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue" })""",
+        },
+        // Failing strings
+        {
+            // Reorder properties
+            R"""({ "street_type": "Avenue", "number": 1600 })""",
+            // Add "direction"
+            R"""({ "number": 1600, "street_name": "Pennsylvania", "street_type": "Avenue", "direction": "NW" })""",
+        });
+
+    // NOTE(review): name duplicates the test above; this one checks declaration-order
+    // enforcement when some properties are required -- consider a distinct label.
+    test_schema("required + optional props each in original order",
+                // Schema
+                R"""({
+            "properties": {
+                "b": {"type": "string"},
+                "a": {"type": "string"},
+                "d": {"type": "string"},
+                "c": {"type": "string"}
+            },
+            "required": ["a", "b"],
+            "additionalProperties": false
+        })""",
+                // Passing strings
+                {
+                    R"""({"b": "foo", "a": "bar"})""",
+                    R"""({"b":"foo","a":"bar","d":"qux"})""",
+                    R"""({"b":"foo", "a":"bar", "d":"qux", "c":"baz"})""",
+                },
+                // Failing strings
+                {
+                    R"""({"a": "foo", "b": "bar"})""",
+                    R"""({"b": "bar"})""",
+                    R"""({"a": "foo", "c": "baz"})""",
+                    R"""({"a":"foo", "b":"bar", "c":"baz", "d":"qux"})""",
+                });
+
+    // NOTE: Example from https://json-schema.org/learn/getting-started-step-by-step#define-required-properties
+    test_schema(
+        "required props",
+        // Schema
+        R"""({
+            "$schema": "https://json-schema.org/draft/2020-12/schema",
+            "$id": "https://example.com/product.schema.json",
+            "title": "Product",
+            "description": "A product from Acme's catalog",
+            "type": "object",
+            "properties": {
+                "productId": {
+                "description": "The unique identifier for a product",
+                "type": "integer"
+                },
+                "productName": {
+                "description": "Name of the product",
+                "type": "string"
+                },
+                "price": {
+                "description": "The price of the product",
+                "type": "number",
+                "exclusiveMinimum": 0
+                },
+                "tags": {
+                "description": "Tags for the product",
+                "type": "array",
+                "items": {
+                    "type": "string"
+                },
+                "minItems": 1,
+                "DISABLED_uniqueItems": true
+                },
+                "dimensions": {
+                "type": "object",
+                "properties": {
+                    "length": {
+                    "type": "number"
+                    },
+                    "width": {
+                    "type": "number"
+                    },
+                    "height": {
+                    "type": "number"
+                    }
+                },
+                "required": [ "length", "width", "height" ]
+                }
+            },
+            "required": [ "productId", "productName", "price" ]
+        })""",
+        // Passing strings
+        {
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50})""",
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"]})""",
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green"], "dimensions": {"length": 785, "width": 250.5, "height": -0.359}})""",
+        },
+        // Failing strings
+        {
+            R"""({})""",  // Missing all required properties
+            R"""({"productName": "A green door", "price": 12.50, "productId": 1})""",  // Out of order properties
+            // `exclusiveMinimum` is OK for llg
+            R"""({"productId": 1, "productName": "A green door", "price": -12.50})""",
+            R"""({"productId": 1, "productName": "A green door"})""",  // Missing required property (price)
+            R"""({"productName": "A green door", "price": 12.50})""",  // Missing required property (productId)
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": []})""",  // tags is empty, but minItems is 1
+            R"""({"productId": 1, "productName": "A green door", "price": 12.50, "dimensions": {"length": 785, "width": 250.5, "height": -0.359}, "tags": ["home", "green"]})""",  // Tags and dimensions are out of order
+            // TODO: The following line should fail, but currently it passes. `uniqueItems` is not supported, as it would likely be too difficult to implement.
+            // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""",
+        });
+}
+
+int main(int argc, const char ** argv) {
+    fprintf(stdout, "Running llguidance integration tests...\n");
+
+    if (argc != 2) {
+        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
+        return 1;
+    }
+
+    const char * vocab_file = argv[1];
+
+    fprintf(stderr, "reading vocab from: '%s'\n", vocab_file);
+
+    llama_model *   model;
+    llama_context * ctx;
+
+    llama_backend_init();
+
+    // load the vocab
+    {
+        auto mparams = llama_model_default_params();
+
+        mparams.vocab_only = true;
+
+        model = llama_model_load_from_file(vocab_file, mparams);
+
+        if (model == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
+            return 1;
+        }
+
+        // needed?
+        auto cparams = llama_context_default_params();
+
+        ctx = llama_init_from_model(model, cparams);
+
+        if (ctx == NULL) {
+            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, vocab_file);
+            llama_model_free(model);
+            return 1;
+        }
+    }
+
+    vocab = llama_model_get_vocab(model);
+
+    test_simple_grammar();
+    test_complex_grammar();
+    test_special_chars();
+    test_quantifiers();
+    test_json_schema();
+    fprintf(stdout, "All tests passed.\n");
+    return 0;
+}
index 9d2db91f52c35de2844174c031e18393ff3811d5..f38994c925e09a99edaea831da1e066d4011449f 100755 (executable)
@@ -1246,7 +1246,7 @@ int main() {
 
     test_all("C++", [](const TestCase & tc) {
         try {
-            tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema)));
+            tc.verify(json_schema_to_grammar(nlohmann::ordered_json::parse(tc.schema), true));
             tc.verify_status(SUCCESS);
         } catch (const std::runtime_error & ex) {
             fprintf(stderr, "Error: %s\n", ex.what());