]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Allow passing grammar to completion endpoint (#2532)
authorMartin Krasser <redacted>
Tue, 8 Aug 2023 13:29:19 +0000 (15:29 +0200)
committerGitHub <redacted>
Tue, 8 Aug 2023 13:29:19 +0000 (16:29 +0300)
* Allow passing grammar to completion endpoint

Makefile
examples/server/README.md
examples/server/server.cpp

index 897c5cb9abccabc791b05fb05dd6d312a7542473..32598edfe847d6db1e67ab153fd6b8dd7d617446 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -380,7 +380,7 @@ embedding: examples/embedding/embedding.cpp                   build-info.h ggml.
 save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
        $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
+server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
        $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
 
 $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
index aee31ae42e517927a958b8c53db33e27b1f5ede5..e56ca063a9f0e66e3d44e2fd91bfa63b66e525ae 100644 (file)
@@ -151,6 +151,8 @@ node .
 
     `mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).
 
+    `grammar`: Set grammar for grammar-based sampling (default: no grammar)
+
     `seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
 
     `ignore_eos`: Ignore end of stream token and continue generating (default: false).
index 6f7a66da108c855f8aa9834a68a7ccf4d9cbff42..10ae264f516f4e80af02b7c320fa45a24cc06669 100644 (file)
@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -195,6 +196,8 @@ struct llama_server_context
     llama_context *ctx = nullptr;
     gpt_params params;
 
+    llama_grammar *grammar = nullptr;
+
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -226,6 +229,7 @@ struct llama_server_context
     void rewind()
     {
         params.antiprompt.clear();
+        params.grammar.clear();
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
@@ -237,6 +241,7 @@ struct llama_server_context
         stopped_limit = false;
         stopping_word = "";
         multibyte_pending = 0;
+        grammar = nullptr;
 
         n_remain = 0;
         n_past = 0;
@@ -257,6 +262,33 @@ struct llama_server_context
         return true;
     }
 
+    bool loadGrammar()
+    {
+        if (!params.grammar.empty()) {
+            grammar_parser::parse_state parsed_grammar;
+
+            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar.rules.empty()) {
+                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
+                return false;
+            }
+            grammar_parser::print_grammar(stderr, parsed_grammar);
+
+            {
+                auto it = params.logit_bias.find(llama_token_eos());
+                if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
+                }
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+            grammar = llama_grammar_init(
+                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+        }
+        return true;
+    }
+
     void loadPrompt()
     {
         params.prompt.insert(0, 1, ' '); // always add a first space
@@ -420,6 +452,10 @@ struct llama_server_context
                 logits[llama_token_nl()] = nl_logit;
             }
 
+            if (grammar != nullptr) {
+                llama_sample_grammar(ctx, &candidates_p, grammar);
+            }
+
             if (temp <= 0)
             {
                 // Greedy sampling
@@ -457,10 +493,15 @@ struct llama_server_context
                 }
             }
 
+            if (grammar != nullptr) {
+                llama_grammar_accept_token(ctx, grammar, result.tok);
+            }
+
             for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
             {
                 result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
             }
+
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
@@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
         {"stream", llama.stream},
         {"logit_bias", llama.params.logit_bias},
         {"n_probs", llama.params.n_probs},
+        {"grammar", llama.params.grammar},
     };
 }
 
@@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.grammar = body.value("grammar", default_params.grammar);
     llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
@@ -1179,6 +1222,12 @@ int main(int argc, char **argv)
 
         parse_options_completion(json::parse(req.body), llama);
 
+        if (!llama.loadGrammar())
+        {
+            res.status = 400;
+            return;
+        }
+
         llama.loadPrompt();
         llama.beginCompletion();
 
@@ -1334,8 +1383,12 @@ int main(int argc, char **argv)
 
     svr.set_error_handler([](const Request &, Response &res)
                           {
-        res.set_content("File Not Found", "text/plain");
-        res.status = 404; });
+        if (res.status == 400) {
+            res.set_content("Invalid request", "text/plain");
+        } else {
+            res.set_content("File Not Found", "text/plain");
+            res.status = 404;
+        } });
 
     // set timeouts and change hostname and port
     svr.set_read_timeout(sparams.read_timeout);
@@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
         return 1;
     }
 
+    if (llama.grammar != nullptr) {
+        llama_grammar_free(llama.grammar);
+    }
     llama_backend_free();
 
     return 0;