]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : various fixes for the prompt field in /completion (#5300)
authorNiall Coates <redacted>
Tue, 6 Feb 2024 08:16:23 +0000 (08:16 +0000)
committerGitHub <redacted>
Tue, 6 Feb 2024 08:16:23 +0000 (10:16 +0200)
server : fix deadlock when prompt array contains strings and numbers

server : removed an unnecessary generation when generating multi-prompts

server : removed an unnecessary assert

examples/server/server.cpp

index 8000fee5c90d79c1ef1b413b3f80a7b4360b9584..fc7e723a13573c42c8856b3210a782d7c48e2b14 100644 (file)
@@ -1163,13 +1163,30 @@ struct llama_server_context
         task.multitask_id = multitask_id;
 
         // when a completion task's prompt array is not a singleton, we split it into multiple requests
-        if (task.data.count("prompt") && task.data.at("prompt").size() > 1)
-        {
-            split_multiprompt_task(task_id, task);
-        }
-
         // otherwise, it's a single-prompt task, we actually queue it
-        queue_tasks.post(task);
+        // if there's numbers in the prompt array it will be treated as an array of tokens
+        if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
+            bool numbers = false;
+            for (const auto& e : task.data.at("prompt")) {
+                if (e.is_number()) {
+                    numbers = true;
+                    break;
+                }
+            }
+
+            // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
+            // it will completely stall the server. I don't know where the bug for this is.
+            //
+            // if there are numbers, it needs to be treated like a single prompt,
+            // queue_tasks handles a mix of strings and numbers just fine.
+            if (numbers) {
+                queue_tasks.post(task);
+            } else {
+                split_multiprompt_task(task_id, task);
+            }
+        } else {
+            queue_tasks.post(task);
+        }
     }
 
     // for multiple images processing
@@ -1251,7 +1268,10 @@ struct llama_server_context
     void split_multiprompt_task(int multitask_id, task_server& multiprompt_task)
     {
         int prompt_count = multiprompt_task.data.at("prompt").size();
-        assert(prompt_count > 1);
+        if (prompt_count <= 1) {
+            send_error(multiprompt_task, "error while handling multiple prompts");
+            return;
+        }
 
         // generate all the ID for subtask
         std::vector<int> subtask_ids(prompt_count);