]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
`main`: add --json-schema / -j flag (#6659)
authorOlivier Chafik <redacted>
Mon, 15 Apr 2024 17:35:21 +0000 (18:35 +0100)
committerGitHub <redacted>
Mon, 15 Apr 2024 17:35:21 +0000 (18:35 +0100)
* main: add --json-schema / -j

* json: move json-schema-to-grammar to common lib

* json: fix zig build

Makefile
build.zig
common/CMakeLists.txt
common/common.cpp
examples/main/README.md
examples/server/CMakeLists.txt
tests/CMakeLists.txt

index 7a69ad1b3c14fe0b1f6fe64797b97cf3bad9ca06..8f3e17da4c57bafa7122aea8b5f7758447265c1c 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -688,7 +688,7 @@ llama.o: llama.cpp unicode.h ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml
        $(CXX) $(CXXFLAGS) -c $< -o $@
 
 COMMON_H_DEPS = common/common.h common/sampling.h common/log.h
-COMMON_DEPS   = common.o sampling.o grammar-parser.o build-info.o
+COMMON_DEPS   = common.o sampling.o grammar-parser.o build-info.o json-schema-to-grammar.o
 
 common.o: common/common.cpp $(COMMON_H_DEPS)
        $(CXX) $(CXXFLAGS) -c $< -o $@
@@ -756,7 +756,7 @@ batched: examples/batched/batched.cpp                         ggml.o llama.o $(C
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-batched-bench: examples/batched-bench/batched-bench.cpp       build-info.o ggml.o llama.o common.o $(OBJS)
+batched-bench: examples/batched-bench/batched-bench.cpp       build-info.o ggml.o llama.o $(COMMON_DEPS) $(OBJS)
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
@@ -788,7 +788,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp json-schema-to-grammar.o common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
        $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
        $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
 
index 7f36e596888da120684ec226403fdc0cd1c435c0..e05ca2120ba4cbd1d13b0532ca07b0fd6e0044e9 100644 (file)
--- a/build.zig
+++ b/build.zig
@@ -128,14 +128,14 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, train });
 
-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, buildinfo, sampling, grammar_parser, json_schema_to_grammar, clip, llava });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, grammar_parser, clip, llava });
     if (server.target.isWindows()) {
         server.linkSystemLibrary("ws2_32");
     }
index 1d840e5f7387730fc92d03883cb36efa06286a74..0ec8d6d8d03b5319b7095ab84ccd60cc74ff44a8 100644 (file)
@@ -47,9 +47,6 @@ if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
 endif()
 
-set(TARGET json-schema-to-grammar)
-add_library(${TARGET} OBJECT json-schema-to-grammar.cpp json-schema-to-grammar.h)
-
 set(TARGET common)
 
 add_library(${TARGET} STATIC
@@ -63,6 +60,7 @@ add_library(${TARGET} STATIC
     grammar-parser.h
     grammar-parser.cpp
     json.hpp
+    json-schema-to-grammar.cpp
     train.h
     train.cpp
     ngram-cache.h
index dda514785171b0fd3b51cb91a4db3b3ee901a647..52576cba37bdd5efe9fcadb4427ba6e7b0e24d1f 100644 (file)
@@ -1,4 +1,6 @@
 #include "common.h"
+#include "json.hpp"
+#include "json-schema-to-grammar.h"
 #include "llama.h"
 
 #include <algorithm>
@@ -68,6 +70,8 @@
 #define LLAMA_CURL_MAX_HEADER_LENGTH 256
 #endif // LLAMA_USE_CURL
 
+using json = nlohmann::ordered_json;
+
 int32_t get_num_physical_cores() {
 #ifdef __linux__
     // enumerate the set of thread siblings, num entries is num cores
@@ -1148,6 +1152,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         );
         return true;
     }
+    if (arg == "-j" || arg == "--json-schema") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
+        return true;
+    }
     if (arg == "--override-kv") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1353,6 +1365,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     printf("  --grammar GRAMMAR     BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
     printf("  --grammar-file FNAME  file to read grammar from\n");
+    printf("  -j SCHEMA, --json-schema SCHEMA\n");
+    printf("                        JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
+    printf("                        For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
     printf("  --cfg-negative-prompt PROMPT\n");
     printf("                        negative prompt to use for guidance. (default: empty)\n");
     printf("  --cfg-negative-prompt-file FNAME\n");
index 10a589cebb34536496d16f5419eda2122b7704d1..649f4e0f35820e75a01917c907211141b7eb7f67 100644 (file)
@@ -304,10 +304,12 @@ These options help improve the performance and memory usage of the LLaMA models.
 
 -   `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.
 
-### Grammars
+### Grammars & JSON schemas
 
 -   `--grammar GRAMMAR`, `--grammar-file FILE`: Specify a grammar (defined inline or in a file) to constrain model output to a specific format. For example, you could force the model to output JSON or to speak only in emojis. See the [GBNF guide](../../grammars/README.md) for details on the syntax.
 
+-   `--json-schema SCHEMA`: Specify a [JSON schema](https://json-schema.org/) to constrain model output to (e.g. `{}` for any JSON object, or `{"items": {"type": "string", "minLength": 10, "maxLength": 100}, "minItems": 10}` for a JSON array of strings with size constraints). If a schema uses external `$ref`s, you should use `--grammar "$( python examples/json_schema_to_grammar.py myschema.json )"` instead.
+
 ### Quantization
 
 For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize).
index d2ee47d01d7678d3445dc208e1f3ed422c1bae9f..61f58417449a8a842d117077ea658872dfd989a2 100644 (file)
@@ -11,7 +11,7 @@ install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-target_link_libraries(${TARGET} PRIVATE common json-schema-to-grammar ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
 if (LLAMA_SERVER_SSL)
     find_package(OpenSSL REQUIRED)
     target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
index b5d7bb59c60dbef59243b87378903eeda3c86db0..89f23ca2de9986e9a83362736fb03f5990ff6351 100644 (file)
@@ -25,7 +25,7 @@ function(llama_test source)
 
     add_executable(${TEST_TARGET} ${source} get-model.cpp)
     install(TARGETS ${TEST_TARGET} RUNTIME)
-    target_link_libraries(${TEST_TARGET} PRIVATE common json-schema-to-grammar)
+    target_link_libraries(${TEST_TARGET} PRIVATE common)
     add_test(
         NAME ${TEST_TARGET}
         WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}