git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
JSON: [key] -> .at(key), assert() -> GGML_ASSERT (#7143)
author Johannes Gäßler <redacted>
Wed, 8 May 2024 19:53:08 +0000 (21:53 +0200)
committer GitHub <redacted>
Wed, 8 May 2024 19:53:08 +0000 (21:53 +0200)
common/common.cpp
common/json-schema-to-grammar.h
examples/server/server.cpp
examples/server/utils.hpp
tests/test-json-schema-to-grammar.cpp
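As a minimal sketch (not part of the commit), the two changes applied across the files above amount to: define JSON_ASSERT, nlohmann::json's assertion hook, as GGML_ASSERT before including json.hpp, and read keys with at(), which throws nlohmann::json::out_of_range for a missing key, instead of operator[], which silently inserts null on a non-const object and falls back to JSON_ASSERT on a const one. The read_etag helper below is hypothetical and only mirrors the metadata handling in common/common.cpp.

// Hypothetical example, assuming the llama.cpp tree (GGML_ASSERT from ggml.h,
// json.hpp being the bundled nlohmann::json single header).
#include "ggml.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT; must come before json.hpp:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"

#include <string>

// at("etag") throws nlohmann::json::out_of_range if the key is missing,
// whereas operator[] would insert a null value (non-const json) or hit
// JSON_ASSERT (const json). The contains()/is_string() guard keeps the
// read exception-free.
static std::string read_etag(const nlohmann::json & metadata) {
    if (metadata.contains("etag") && metadata.at("etag").is_string()) {
        return metadata.at("etag").get<std::string>();
    }
    return "";
}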

common/common.cpp
index 4a9da284e7ec9d7ecdcd2427d1c90ec17a0805c8..0535508ba98dfafd8ce5658b278160b20fb825a4 100644
@@ -1,4 +1,6 @@
 #include "common.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
@@ -1969,18 +1971,18 @@ static bool llama_download_file(const std::string & url, const std::string & pat
             try {
                 metadata_in >> metadata;
                 fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
-                if (metadata.contains("url") && metadata["url"].is_string()) {
-                    auto previous_url = metadata["url"].get<std::string>();
+                if (metadata.contains("url") && metadata.at("url").is_string()) {
+                    auto previous_url = metadata.at("url").get<std::string>();
                     if (previous_url != url) {
                         fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
                         return false;
                     }
                 }
-                if (metadata.contains("etag") && metadata["etag"].is_string()) {
-                    etag = metadata["etag"];
+                if (metadata.contains("etag") && metadata.at("etag").is_string()) {
+                    etag = metadata.at("etag");
                 }
-                if (metadata.contains("lastModified") && metadata["lastModified"].is_string()) {
-                    last_modified = metadata["lastModified"];
+                if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
+                    last_modified = metadata.at("lastModified");
                 }
             } catch (const nlohmann::json::exception & e) {
                 fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
common/json-schema-to-grammar.h
index e1abed30375826bbf7f201a0a17ff56c6c810630..41623b34645287ab14a24bcb729c79141129cf77 100644
@@ -1,4 +1,8 @@
 #pragma once
+
+#include "ggml.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
 std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
examples/server/server.cpp
index 06c0be56749ab3991b2c6e0bcfbedaf4243f0415..305f79492a0552b63a777629abda5f7350f2239a 100644
@@ -12,6 +12,8 @@
 // increase max payload length to allow use of larger context size
 #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
 #include "httplib.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
 // auto generated files (update with ./deps.sh)
@@ -859,7 +861,7 @@ struct server_context {
         slot.sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
 
         // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data["json_schema"].is_null() && data.contains("grammar") && !data["grammar"].is_null()) {
+        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
         } else if (data.contains("json_schema") && !data.contains("grammar")) {
@@ -1512,7 +1514,7 @@ struct server_context {
         // add subtasks
         for (int i = 0; i < prompt_count; i++) {
             json subtask_data = multiprompt_task.data;
-            subtask_data["prompt"] = subtask_data["prompt"][i];
+            subtask_data["prompt"] = subtask_data.at("prompt")[i];
 
             // subtasks inherit everything else (infill mode, embedding mode, etc.)
             request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
@@ -1532,7 +1534,7 @@ struct server_context {
                     }
 
                     if (task.data.contains("system_prompt")) {
-                        system_prompt_set(task.data["system_prompt"]);
+                        system_prompt_set(task.data.at("system_prompt"));
 
                         for (server_slot & slot : slots) {
                             slot.n_past    = 0;
@@ -1644,7 +1646,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
-                    int id_slot = task.data["id_slot"];
+                    int id_slot = task.data.at("id_slot");
                     server_slot * slot = get_slot(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1654,8 +1656,8 @@ struct server_context {
                     const size_t token_count = slot->cache_tokens.size();
                     const int64_t t_start = ggml_time_us();
 
-                    std::string filename = task.data["filename"];
-                    std::string filepath = task.data["filepath"];
+                    std::string filename = task.data.at("filename");
+                    std::string filepath = task.data.at("filepath");
 
                     const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
 
@@ -1679,7 +1681,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
-                    int id_slot = task.data["id_slot"];
+                    int id_slot = task.data.at("id_slot");
                     server_slot * slot = get_slot(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1688,8 +1690,8 @@ struct server_context {
 
                     const int64_t t_start = ggml_time_us();
 
-                    std::string filename = task.data["filename"];
-                    std::string filepath = task.data["filepath"];
+                    std::string filename = task.data.at("filename");
+                    std::string filepath = task.data.at("filepath");
 
                     slot->cache_tokens.resize(slot->n_ctx);
                     size_t token_count = 0;
@@ -1721,7 +1723,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
-                    int id_slot = task.data["id_slot"];
+                    int id_slot = task.data.at("id_slot");
                     server_slot * slot = get_slot(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -3136,8 +3138,8 @@ int main(int argc, char ** argv) {
                     server_task_result result = ctx_server.queue_results.recv(task.id);
                     ctx_server.queue_results.remove_waiting_task_id(task.id);
 
-                    const int n_idle_slots       = result.data["idle"];
-                    const int n_processing_slots = result.data["processing"];
+                    const int n_idle_slots       = result.data.at("idle");
+                    const int n_processing_slots = result.data.at("processing");
 
                     json health = {
                         {"status",           "ok"},
@@ -3147,7 +3149,7 @@ int main(int argc, char ** argv) {
 
                     res.status = 200; // HTTP OK
                     if (sparams.slots_endpoint && req.has_param("include_slots")) {
-                        health["slots"] = result.data["slots"];
+                        health["slots"] = result.data.at("slots");
                     }
 
                     if (n_idle_slots == 0) {
@@ -3191,7 +3193,7 @@ int main(int argc, char ** argv) {
         server_task_result result = ctx_server.queue_results.recv(task.id);
         ctx_server.queue_results.remove_waiting_task_id(task.id);
 
-        res.set_content(result.data["slots"].dump(), "application/json");
+        res.set_content(result.data.at("slots").dump(), "application/json");
         res.status = 200; // HTTP OK
     };
 
@@ -3218,32 +3220,32 @@ int main(int argc, char ** argv) {
 
         json data = result.data;
 
-        const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
-        const uint64_t t_prompt_processing       = data["t_prompt_processing"];
+        const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed");
+        const uint64_t t_prompt_processing       = data.at("t_prompt_processing");
 
-        const uint64_t n_tokens_predicted  = data["n_tokens_predicted"];
-        const uint64_t t_tokens_generation = data["t_tokens_generation"];
+        const uint64_t n_tokens_predicted  = data.at("n_tokens_predicted");
+        const uint64_t t_tokens_generation = data.at("t_tokens_generation");
 
-        const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
+        const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
 
         // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
         json all_metrics_def = json {
             {"counter", {{
                     {"name",  "prompt_tokens_total"},
                     {"help",  "Number of prompt tokens processed."},
-                    {"value",  (uint64_t) data["n_prompt_tokens_processed_total"]}
+                    {"value",  (uint64_t) data.at("n_prompt_tokens_processed_total")}
             }, {
                     {"name",  "prompt_seconds_total"},
                     {"help",  "Prompt process time"},
-                    {"value",  (uint64_t) data["t_prompt_processing_total"] / 1.e3}
+                    {"value",  (uint64_t) data.at("t_prompt_processing_total") / 1.e3}
             }, {
                     {"name",  "tokens_predicted_total"},
                     {"help",  "Number of generation tokens processed."},
-                    {"value",  (uint64_t) data["n_tokens_predicted_total"]}
+                    {"value",  (uint64_t) data.at("n_tokens_predicted_total")}
             }, {
                     {"name",  "tokens_predicted_seconds_total"},
                     {"help",  "Predict process time"},
-                    {"value",  (uint64_t) data["t_tokens_generation_total"] / 1.e3}
+                    {"value",  (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
             }}},
             {"gauge", {{
                     {"name",  "prompt_tokens_seconds"},
@@ -3260,15 +3262,15 @@ int main(int argc, char ** argv) {
             },{
                     {"name",  "kv_cache_tokens"},
                     {"help",  "KV-cache tokens."},
-                    {"value",  (uint64_t) data["kv_cache_tokens_count"]}
+                    {"value",  (uint64_t) data.at("kv_cache_tokens_count")}
             },{
                     {"name",  "requests_processing"},
                     {"help",  "Number of request processing."},
-                    {"value",  (uint64_t) data["processing"]}
+                    {"value",  (uint64_t) data.at("processing")}
             },{
                     {"name",  "requests_deferred"},
                     {"help",  "Number of request deferred."},
-                    {"value",  (uint64_t) data["deferred"]}
+                    {"value",  (uint64_t) data.at("deferred")}
             }}}
         };
 
@@ -3279,8 +3281,8 @@ int main(int argc, char ** argv) {
             const auto & metrics_def = el.value();
 
             for (const auto & metric_def : metrics_def) {
-                const std::string name = metric_def["name"];
-                const std::string help = metric_def["help"];
+                const std::string name = metric_def.at("name");
+                const std::string help = metric_def.at("help");
 
                 auto value = json_value(metric_def, "value", 0.);
                 prometheus << "# HELP llamacpp:" << name << " " << help  << "\n"
@@ -3289,7 +3291,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        const int64_t t_start = data["t_start"];
+        const int64_t t_start = data.at("t_start");
         res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
 
         res.set_content(prometheus.str(), "text/plain; version=0.0.4");
@@ -3298,7 +3300,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
-        std::string filename = request_data["filename"];
+        std::string filename = request_data.at("filename");
         if (!validate_file_name(filename)) {
             res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
             return;
@@ -3328,7 +3330,7 @@ int main(int argc, char ** argv) {
 
     const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
-        std::string filename = request_data["filename"];
+        std::string filename = request_data.at("filename");
         if (!validate_file_name(filename)) {
             res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
             return;
@@ -3648,7 +3650,7 @@ int main(int argc, char ** argv) {
         std::vector<llama_token> tokens;
         if (body.count("content") != 0) {
             const bool add_special = json_value(body, "add_special", false);
-            tokens = ctx_server.tokenize(body["content"], add_special);
+            tokens = ctx_server.tokenize(body.at("content"), add_special);
         }
         const json data = format_tokenizer_response(tokens);
         return res.set_content(data.dump(), "application/json; charset=utf-8");
@@ -3660,7 +3662,7 @@ int main(int argc, char ** argv) {
 
         std::string content;
         if (body.count("tokens") != 0) {
-            const std::vector<llama_token> tokens = body["tokens"];
+            const std::vector<llama_token> tokens = body.at("tokens");
             content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
         }
 
@@ -3683,10 +3685,10 @@ int main(int argc, char ** argv) {
         json prompt;
         if (body.count("input") != 0) {
             is_openai = true;
-            prompt = body["input"];
+            prompt = body.at("input");
         } else if (body.count("content") != 0) {
             // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body["content"]};
+            prompt = std::vector<std::string>{body.at("content")};
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
@@ -3705,7 +3707,7 @@ int main(int argc, char ** argv) {
             if (!result.error) {
                 if (result.data.count("results")) {
                     // result for multi-task
-                    responses = result.data["results"];
+                    responses = result.data.at("results");
                 } else {
                     // result for single task
                     responses = std::vector<json>{result.data};
examples/server/utils.hpp
index af12f497d327dfba8814cee61b054b470ffe792f..d872b63f537f47767c1530ae39ab223c14f94d71 100644
@@ -3,6 +3,8 @@
 #include "llama.h"
 #include "common.h"
 
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
 #include <string>
@@ -373,11 +375,11 @@ static json oaicompat_completion_params_parse(
     llama_params["top_p"]             = json_value(body,   "top_p",             1.0);
 
     // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
+    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
 
     // Handle "stop" field
-    if (body.contains("stop") && body["stop"].is_string()) {
-        llama_params["stop"] = json::array({body["stop"].get<std::string>()});
+    if (body.contains("stop") && body.at("stop").is_string()) {
+        llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
tests/test-json-schema-to-grammar.cpp
index b2ce4d260a5a422aeb0a5c3b7fb915ddd0508cf9..c5361b5b8912c65e83339c7d432e6d75e76e3de4 100755
@@ -2,6 +2,7 @@
 #undef NDEBUG
 #endif
 
+#include <cassert>
 #include <fstream>
 #include <sstream>
 #include <regex>