]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : improve "prompt" handling (#7847)
authorGeorgi Gerganov <redacted>
Mon, 10 Jun 2024 11:59:55 +0000 (14:59 +0300)
committerGitHub <redacted>
Mon, 10 Jun 2024 11:59:55 +0000 (14:59 +0300)
examples/server/server.cpp

index 6ffaa8d9fe6374dc9d750d9e53d0abab3b543ec7..80714fa58360b35ad9405d20e5080f40a1972ff3 100644 (file)
@@ -147,7 +147,7 @@ struct server_slot {
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    json prompt;
+    std::string prompt;
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
@@ -822,13 +822,8 @@ struct server_context {
                     continue;
                 }
 
-                // skip the slot if it does not contains prompt
-                if (!slot.prompt.is_string()) {
-                    continue;
-                }
-
                 // current slot's prompt
-                std::string slot_prompt = slot.prompt.get<std::string>();
+                std::string slot_prompt = slot.prompt;
 
                 // length of the current slot's prompt
                 int slot_prompt_len = slot_prompt.size();
@@ -958,13 +953,16 @@ struct server_context {
         if (!task.infill) {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
-                send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
                 return false;
-            } else {
-                slot.prompt = *prompt;
             }
-            if (slot.prompt.is_array() && slot.prompt.size() == 0) {
-                send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
+
+            if (prompt->is_string()) {
+                slot.prompt = prompt->get<std::string>();
+            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
+                slot.prompt = prompt->at(0).get<std::string>();
+            } else {
+                send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }
@@ -1582,14 +1580,18 @@ struct server_context {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
                 {
-                    int id_slot        = json_value(task.data, "id_slot", -1);
-                    std::string prompt = json_value(task.data, "prompt", std::string());
+                    const int id_slot = json_value(task.data, "id_slot", -1);
 
                     server_slot * slot;
 
                     if (id_slot != -1) {
                         slot = get_slot_by_id(id_slot);
                     } else {
+                        std::string prompt;
+                        if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
+                            json_value(task.data, "prompt", std::string());
+                        }
+
                         slot = get_available_slot(prompt);
                     }