]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
fix system prompt handling (#7153)
authorXuan Son Nguyen <redacted>
Sat, 11 May 2024 15:28:10 +0000 (17:28 +0200)
committerGitHub <redacted>
Sat, 11 May 2024 15:28:10 +0000 (17:28 +0200)
examples/server/server.cpp

index 55c1d41298cd662b6b9a546654ffab3b29bc15d8..ceaeb1f76dc3dec4cc1321dc0dc858663f2bbb16 100644 (file)
@@ -651,9 +651,6 @@ struct server_context {
     std::string              system_prompt;
     std::vector<llama_token> system_tokens;
 
-    std::string name_user;      // this should be the antiprompt
-    std::string name_assistant;
-
     // slots / clients
     std::vector<server_slot> slots;
     json default_generation_settings_for_props;
@@ -1100,15 +1097,11 @@ struct server_context {
         system_need_update = false;
     }
 
-    void system_prompt_set(const json & sys_props) {
-        system_prompt  = sys_props.value("prompt", "");
-        name_user      = sys_props.value("anti_prompt", "");
-        name_assistant = sys_props.value("assistant_name", "");
+    bool system_prompt_set(const std::string & sys_prompt) {
+        system_prompt = sys_prompt;
 
         LOG_VERBOSE("system prompt process", {
             {"system_prompt",  system_prompt},
-            {"name_user",      name_user},
-            {"name_assistant", name_assistant},
         });
 
         // release all slots
@@ -1117,6 +1110,7 @@ struct server_context {
         }
 
         system_need_update = true;
+        return true;
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
@@ -1536,7 +1530,8 @@ struct server_context {
                     }
 
                     if (task.data.contains("system_prompt")) {
-                        system_prompt_set(task.data.at("system_prompt"));
+                        std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
+                        system_prompt_set(sys_prompt);
 
                         for (server_slot & slot : slots) {
                             slot.n_past    = 0;
@@ -2920,7 +2915,7 @@ int main(int argc, char ** argv) {
     server_params_parse(argc, argv, sparams, params);
 
     if (!sparams.system_prompt.empty()) {
-        ctx_server.system_prompt_set(json::parse(sparams.system_prompt));
+        ctx_server.system_prompt_set(sparams.system_prompt);
     }
 
     if (params.model_alias == "unknown") {
@@ -3409,8 +3404,7 @@ int main(int argc, char ** argv) {
     const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
-            { "user_name",                   ctx_server.name_user.c_str() },
-            { "assistant_name",              ctx_server.name_assistant.c_str() },
+            { "system_prompt",               ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params.n_parallel }
         };