]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : smart slot selection using Longest Common Prefix (#7728)
authorsasha0552 <redacted>
Sat, 8 Jun 2024 07:50:31 +0000 (07:50 +0000)
committerGitHub <redacted>
Sat, 8 Jun 2024 07:50:31 +0000 (10:50 +0300)
* server : Smart selection of available slot using Longest Common Substring

* add usage

* remove trailing whitespaces

* Use Longest Common Prefix (LCP) instead of LCS

* Rename argument

common/common.cpp
common/common.h
examples/server/server.cpp
examples/server/utils.hpp

index cdcb352b5a8aed395a5c21f1393db7ae31c0452a..d2a8bb69e728fb371329c215fbfaabee452890d5 100644 (file)
@@ -1491,6 +1491,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.chat_template = argv[i];
         return true;
     }
+    if (arg == "--slot-prompt-similarity" || arg == "-sps") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.slot_prompt_similarity = std::stof(argv[i]);
+        return true;
+    }
     if (arg == "-pps") {
         params.is_pp_shared = true;
         return true;
@@ -1913,6 +1921,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                         "set custom jinja chat template (default: template taken from model's metadata)\n"
                                                                         "only commonly used templates are accepted:\n"
                                                                         "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
+    options.push_back({ "server",      "-sps,  --slot-prompt-similarity SIMILARITY",
+                                                                        "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
 
 #ifndef LOG_DISABLE_LOGS
     options.push_back({ "logging" });
index 35f5311e10fe1abee425d9372624349f6b9bfae9..038f9084f342451d80463344f14ba0936cd69468 100644 (file)
@@ -203,6 +203,8 @@ struct gpt_params {
 
     std::string slot_save_path;
 
+    float slot_prompt_similarity = 0.5f;
+
     // batched-bench params
     bool is_pp_shared = false;
 
index 528220607a4f62af2e24e559794125aebcb1ad0e..6ffaa8d9fe6374dc9d750d9e53d0abab3b543ec7 100644 (file)
@@ -647,6 +647,9 @@ struct server_context {
 
     server_metrics metrics;
 
+    // Necessary similarity of prompt for slot selection
+    float slot_prompt_similarity = 0.0f;
+
     ~server_context() {
         if (ctx) {
             llama_free(ctx);
@@ -795,24 +798,88 @@ struct server_context {
         return prompt_tokens;
     }
 
-    server_slot * get_slot(int id) {
-        int64_t t_last = ggml_time_us();
-
-        server_slot * last_used = nullptr;
-
+    server_slot * get_slot_by_id(int id) {
         for (server_slot & slot : slots) {
-            if (slot.id == id && slot.available()) {
+            if (slot.id == id) {
                 return &slot;
             }
+        }
+
+        return nullptr;
+    }
+
+    server_slot * get_available_slot(const std::string & prompt) {
+        server_slot * ret = nullptr;
+
+        // find the slot that has at least n% prompt similarity
+        if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
+            int max_lcp_len = 0;
+            float similarity = 0;
+
+            for (server_slot & slot : slots) {
+                // skip the slot if it is not available
+                if (!slot.available()) {
+                    continue;
+                }
+
+                // skip the slot if it does not contains prompt
+                if (!slot.prompt.is_string()) {
+                    continue;
+                }
+
+                // current slot's prompt
+                std::string slot_prompt = slot.prompt.get<std::string>();
+
+                // length of the current slot's prompt
+                int slot_prompt_len = slot_prompt.size();
+
+                // length of the Longest Common Prefix between the current slot's prompt and the input prompt
+                int lcp_len = common_part(slot_prompt, prompt);
+
+                // fraction of the common substring length compared to the current slot's prompt length
+                similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+
+                // select the current slot if the criteria match
+                if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
+                    max_lcp_len = lcp_len;
+                    ret = &slot;
+                }
+            }
 
-            // among all available slots, find the one that has been least recently used
-            if (slot.available() && slot.t_last_used < t_last) {
-                last_used = &slot;
-                t_last = slot.t_last_used;
+            if (ret != nullptr) {
+                LOG_VERBOSE("selected slot by lcp similarity", {
+                    {"id_slot", ret->id},
+                    {"max_lcp_len", max_lcp_len},
+                    {"similarity", similarity},
+                });
             }
         }
 
-        return last_used;
+        // find the slot that has been least recently used
+        if (ret == nullptr) {
+            int64_t t_last = ggml_time_us();
+            for (server_slot & slot : slots) {
+                // skip the slot if it is not available
+                if (!slot.available()) {
+                    continue;
+                }
+
+                // select the current slot if the criteria match
+                if (slot.t_last_used < t_last) {
+                    t_last = slot.t_last_used;
+                    ret = &slot;
+                }
+            }
+
+            if (ret != nullptr) {
+                LOG_VERBOSE("selected slot by lru", {
+                    {"id_slot", ret->id},
+                    {"t_last", t_last},
+                });
+            }
+        }
+
+        return ret;
     }
 
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
@@ -1515,13 +1582,29 @@ struct server_context {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
                 {
-                    server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
+                    int id_slot        = json_value(task.data, "id_slot", -1);
+                    std::string prompt = json_value(task.data, "prompt", std::string());
+
+                    server_slot * slot;
+
+                    if (id_slot != -1) {
+                        slot = get_slot_by_id(id_slot);
+                    } else {
+                        slot = get_available_slot(prompt);
+                    }
+
                     if (slot == nullptr) {
                         // if no slot is available, we defer this task for processing later
                         LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
                         queue_tasks.defer(task);
                         break;
                     }
+                    if (!slot->available()) {
+                        // if requested slot is unavailable, we defer this task for processing later
+                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        queue_tasks.defer(task);
+                        break;
+                    }
 
                     if (task.data.contains("system_prompt")) {
                         std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
@@ -1638,11 +1721,17 @@ struct server_context {
             case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
                     int id_slot = task.data.at("id_slot");
-                    server_slot * slot = get_slot(id_slot);
+                    server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
+                    if (!slot->available()) {
+                        // if requested slot is unavailable, we defer this task for processing later
+                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        queue_tasks.defer(task);
+                        break;
+                    }
 
                     const size_t token_count = slot->cache_tokens.size();
                     const int64_t t_start = ggml_time_us();
@@ -1673,11 +1762,17 @@ struct server_context {
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
                     int id_slot = task.data.at("id_slot");
-                    server_slot * slot = get_slot(id_slot);
+                    server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
+                    if (!slot->available()) {
+                        // if requested slot is unavailable, we defer this task for processing later
+                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        queue_tasks.defer(task);
+                        break;
+                    }
 
                     const int64_t t_start = ggml_time_us();
 
@@ -1715,11 +1810,17 @@ struct server_context {
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
                     int id_slot = task.data.at("id_slot");
-                    server_slot * slot = get_slot(id_slot);
+                    server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
+                    if (!slot->available()) {
+                        // if requested slot is unavailable, we defer this task for processing later
+                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        queue_tasks.defer(task);
+                        break;
+                    }
 
                     // Erase token cache
                     const size_t n_erased = slot->cache_tokens.size();
@@ -2467,6 +2568,9 @@ int main(int argc, char ** argv) {
         log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded";
     }
 
+    // Necessary similarity of prompt for slot selection
+    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
+
     // load the model
     if (!ctx_server.load_model(params)) {
         state.store(SERVER_STATE_ERROR);
index b7bfb41d35edc497ef58c8d17ed1cf0f64c4cc74..63fde9c9faabe3cd68ce1692245706c6aa5dcce7 100644 (file)
@@ -253,6 +253,13 @@ static size_t common_part(const std::vector<llama_token> & a, const std::vector<
     return i;
 }
 
+static size_t common_part(const std::string & a, const std::string & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
 static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }