]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : re-enable completion and embedded at the same time (#3876)
authorAdrian Hesketh <redacted>
Wed, 1 Nov 2023 09:28:28 +0000 (09:28 +0000)
committerGitHub <redacted>
Wed, 1 Nov 2023 09:28:28 +0000 (11:28 +0200)
.gitignore
examples/server/server.cpp

index 545c2872632234a1a7059b1ca64a4410efd80e90..5d7c5479ef67aeb4d18f73df5fc57cfb63149d1c 100644 (file)
@@ -15,6 +15,7 @@
 .DS_Store
 .build/
 .cache/
+.ccls-cache/
 .direnv/
 .envrc
 .swiftpm
index c163c7f8ec0dd12c306794ca8128c0184c997eb7..47ae0d55856cf8d8784eab3df65f807923987757 100644 (file)
@@ -149,6 +149,7 @@ struct task_server {
     task_type type;
     json data;
     bool infill_mode = false;
+    bool embedding_mode = false;
 };
 
 struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
     std::vector<completion_token_output> generated_token_probs;
 
     bool infill = false;
+    bool embedding = false;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
         queue_results.push_back(res);
     }
 
-    int request_completion(json data, bool infill)
+    int request_completion(json data, bool infill, bool embedding)
     {
         std::lock_guard<std::mutex> lock(mutex_tasks);
         task_server task;
         task.id = id_gen++;
         task.data = data;
         task.infill_mode = infill;
+        task.embedding_mode = embedding;
         task.type = COMPLETION_TASK;
         queue_tasks.push_back(task);
         return task.id;
@@ -1376,7 +1379,7 @@ struct llama_server_context
                     {
                         LOG_TEE("slot unavailable\n");
                         // send error result
-                        send_error(task.id, "slot unavaliable");
+                        send_error(task.id, "slot unavailable");
                         return;
                     }
 
@@ -1388,6 +1391,7 @@ struct llama_server_context
                     slot->reset();
 
                     slot->infill = task.infill_mode;
+                    slot->embedding = task.embedding_mode;
                     slot->task_id = task.id;
 
                     if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
                 }
 
                 // prompt evaluated for embedding
-                if (params.embedding)
+                if (slot.embedding)
                 {
                     send_embedding(slot);
                     slot.release();
@@ -2274,7 +2278,7 @@ int main(int argc, char **argv)
     svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, false);
+                const int task_id = llama.request_completion(data, false, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2329,7 +2333,7 @@ int main(int argc, char **argv)
     svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 json data = json::parse(req.body);
-                const int task_id = llama.request_completion(data, true);
+                const int task_id = llama.request_completion(data, true, false);
                 if (!json_value(data, "stream", false)) {
                     std::string completion_text;
                     task_result result = llama.next_result(task_id);
@@ -2433,7 +2437,7 @@ int main(int argc, char **argv)
                 {
                     prompt = "";
                 }
-                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false);
+                const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
                 task_result result = llama.next_result(task_id);
                 return res.set_content(result.result_json.dump(), "application/json");
             });