]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : return nullptr from llama_grammar_init (#8093)
authorDaniel Bevenius <redacted>
Tue, 25 Jun 2024 19:07:28 +0000 (21:07 +0200)
committerGitHub <redacted>
Tue, 25 Jun 2024 19:07:28 +0000 (15:07 -0400)
* llama : return nullptr from llama_grammar_init

This commit updates llama_grammar_init to return nullptr instead of
throwing an exception.

The motivation for this is that this function is declared inside an
extern "C" block and is intended/may be used from C code which will not
be able to handle exceptions thrown, and results in undefined behavior.

On Windows and using MSVC the following warning is currently generated:
```console
C:\llama.cpp\llama.cpp(13998,1): warning C4297: 'llama_grammar_init':
function assumed not to throw an exception but does
C:\llama.cpp\llama.cpp(13998,1): message :
__declspec(nothrow), throw(), noexcept(true), or noexcept was specified
on the function
```

Signed-off-by: Daniel Bevenius <redacted>
* squash! llama : return nullptr from llama_grammar_init

Add checks for nullptr when calling llama_grammar_init.

Signed-off-by: Daniel Bevenius <redacted>
---------

Signed-off-by: Daniel Bevenius <redacted>
Co-authored-by: Clint Herron <redacted>
common/sampling.cpp
examples/gbnf-validator/gbnf-validator.cpp
llama.cpp
llama.h
tests/test-grammar-integration.cpp
tests/test-llama-grammar.cpp

index f1f80351637f0965ebffa64f74e3c9c837963259..9f332fe5736838d58da96edfa8a39fb2ac4a2329 100644 (file)
@@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
         std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
 
-        result->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        result->grammar = grammar;
     }
 
     result->prev.resize(params.n_prev);
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     if (!ctx->parsed_grammar.rules.empty()) {
         std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
 
-        ctx->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        ctx->grammar = grammar;
     }
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
index 0406dc3398b8ae557ea192a134e647050a941eaf..dd53ba9b1d5510673b67e62afa8bb09cd68d8383 100644 (file)
@@ -101,7 +101,9 @@ int main(int argc, char** argv) {
     auto grammar = llama_grammar_init(
             grammar_rules.data(),
             grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-
+    if (grammar == nullptr) {
+        throw std::runtime_error("Failed to initialize llama_grammar");
+    }
     // Read the input file
     std::string input_str;
     {
index 33e6cb7229aab2489565d258811505f9d7f4c1fd..dd2823e65c4b763069ff62a245165f29824b295e 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -14500,7 +14500,8 @@ struct llama_grammar * llama_grammar_init(
             continue;
         }
         if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
-            throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
         }
     }
 
diff --git a/llama.h b/llama.h
index 53e06d9db52733d1c779a77b7b2d9e348e58d610..82d15747f4662d56841eb16c9dd97e4de57f76ab 100644 (file)
--- a/llama.h
+++ b/llama.h
@@ -924,6 +924,12 @@ extern "C" {
     // Grammar
     //
 
+    /// Initialize a llama_grammar.
+    ///
+    /// @param rules The rule elements of the grammar to initialize.
+    /// @param n_rules The number of rules.
+    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
+    /// @return The initialized llama_grammar or nullptr if initialization failed.
     LLAMA_API struct llama_grammar * llama_grammar_init(
             const llama_grammar_element ** rules,
                                  size_t    n_rules,
index 5b3992236c26c02b41cece71d121c309c59c8f3a..5750d362a7247afbe3357c6b747c25604d4399aa 100644 (file)
@@ -36,10 +36,10 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
 static bool test_build_grammar_fails(const std::string & grammar_str) {
     fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
     bool grammar_fails = false;
-    try {
-        build_grammar(grammar_str);
+    llama_grammar * grammar = build_grammar(grammar_str);
+    if (grammar != nullptr) {
         fprintf(stderr, "  ❌ Expected build failure, but succeeded\n");
-    } catch (const std::exception & err) {
+    } else {
         grammar_fails = true;
         fprintf(stdout, "  ✅︎\n");
     }
index 27ca4d2656c5dd8fe1006525f93df5575193d157..c8badb20630761d5a590209632e6ac6bae1b5c6f 100644 (file)
@@ -116,6 +116,10 @@ int main()
     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
     grammar = llama_grammar_init(
         grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    if (grammar == nullptr)
+    {
+        throw std::runtime_error("Failed to initialize llama_grammar");
+    }
 
     std::vector<std::vector<llama_grammar_element>> expected_stacks = {
         {