]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : restore numeric prompts (#7883)
authorGeorgi Gerganov <redacted>
Wed, 12 Jun 2024 11:42:29 +0000 (14:42 +0300)
committerGitHub <redacted>
Wed, 12 Jun 2024 11:42:29 +0000 (14:42 +0300)
examples/server/server.cpp

index 80714fa58360b35ad9405d20e5080f40a1972ff3..919078f2bd920553c75944e81954167dbf0dc083 100644 (file)
@@ -147,7 +147,7 @@ struct server_slot {
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    std::string prompt;
+    json prompt; // can be either a string, array of strings or array of token ids
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
@@ -822,8 +822,13 @@ struct server_context {
                     continue;
                 }
 
+                // skip the slot if it does not contains prompt
+                if (!slot.prompt.is_string()) {
+                    continue;
+                }
+
                 // current slot's prompt
-                std::string slot_prompt = slot.prompt;
+                std::string slot_prompt = slot.prompt.get<std::string>();
 
                 // length of the current slot's prompt
                 int slot_prompt_len = slot_prompt.size();
@@ -957,12 +962,12 @@ struct server_context {
                 return false;
             }
 
-            if (prompt->is_string()) {
-                slot.prompt = prompt->get<std::string>();
-            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
-                slot.prompt = prompt->at(0).get<std::string>();
+            if ((prompt->is_string()) ||
+                (prompt->is_array() &&  prompt->size() == 1 && prompt->at(0).is_string()) ||
+                (prompt->is_array() && !prompt->empty()     && prompt->at(0).is_number_integer())) {
+                slot.prompt = *prompt;
             } else {
-                send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
+                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
         }