]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix memory bug in grammar parser (#7194)
authorJustine Tunney <redacted>
Fri, 10 May 2024 11:01:08 +0000 (07:01 -0400)
committerGitHub <redacted>
Fri, 10 May 2024 11:01:08 +0000 (21:01 +1000)
The llama.cpp grammar parser had a bug where forgetting to add a closing
quotation mark to strings would cause parsing to crash. Anyone running a
server on a public endpoint is advised to upgrade. To reproduce this bug

    ./llamafile -m foo.gguf -p bar --grammar 'root::="'

Credit for discovering and reporting this issue goes to Eclypsium
Security Researcher Richard Johnson <redacted>.

common/common.cpp
common/grammar-parser.cpp
examples/llava/llava-cli.cpp
examples/main/main.cpp

index 484e673349071b77b8bbdbd5b1d1a45513a82ba1..ba1ecf0e59c8beb61bac94b1ec201557b9759755 100644 (file)
@@ -1371,14 +1371,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
-
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
-    }
-
-    if (invalid_param) {
-        throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        if (invalid_param) {
+            throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+        }
     }
 
     if (params.prompt_cache_all &&
index 2a1301569793ad8e902b969f8c15aff0cc96f214..fecb7cd713ea196a819440edd60431e7420400d3 100644 (file)
@@ -142,6 +142,9 @@ namespace grammar_parser {
                 pos++;
                 last_sym_start = out_elements.size();
                 while (*pos != '"') {
+                    if (!*pos) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
                     auto char_pair = parse_char(pos);
                          pos       = char_pair.second;
                     out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
@@ -156,6 +159,9 @@ namespace grammar_parser {
                 }
                 last_sym_start = out_elements.size();
                 while (*pos != ']') {
+                    if (!*pos) {
+                        throw std::runtime_error("unexpected end of input");
+                    }
                     auto char_pair = parse_char(pos);
                          pos       = char_pair.second;
                     enum llama_gretype type = last_sym_start < out_elements.size()
@@ -164,6 +170,9 @@ namespace grammar_parser {
 
                     out_elements.push_back({type, char_pair.first});
                     if (pos[0] == '-' && pos[1] != ']') {
+                        if (!pos[1]) {
+                            throw std::runtime_error("unexpected end of input");
+                        }
                         auto endchar_pair = parse_char(pos + 1);
                              pos          = endchar_pair.second;
                         out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
index 157a680b5ecdb04d3ed5868b28fd52db829bc36b..da60ddf2f057dba3a5de2679707faf7b9e1cd840 100644 (file)
@@ -189,6 +189,11 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
     LOG_TEE("\n");
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    if (!ctx_sampling) {
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
+
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
index f3e445c16d6a9ba5512f27f00a5a41cc0a38bf1a..9dee41001f12c17124b00da0204f19f020331adb 100644 (file)
@@ -523,6 +523,10 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    if (!ctx_sampling) {
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict