]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : fix smart selection of available slot (#10120)
authorsasha0552 <redacted>
Fri, 1 Nov 2024 13:33:14 +0000 (13:33 +0000)
committerGitHub <redacted>
Fri, 1 Nov 2024 13:33:14 +0000 (14:33 +0100)
* Fix smart selection of available slot

* minor fix

* replace vectors of tokens with shorthands

examples/server/server.cpp
examples/server/utils.hpp

index f914ff88caee6c0a6e823f8d8e6191254cab8a9e..54cdb4b72e64ff6affef02fcc53fde5836be3b96 100644 (file)
@@ -725,12 +725,12 @@ struct server_context {
         return nullptr;
     }
 
-    server_slot * get_available_slot(const std::string & prompt) {
+    server_slot * get_available_slot(const server_task & task) {
         server_slot * ret = nullptr;
 
         // find the slot that has at least n% prompt similarity
-        if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
-            int max_lcp_len = 0;
+        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+            int max_lcs_len = 0;
             float similarity = 0;
 
             for (server_slot & slot : slots) {
@@ -740,25 +740,25 @@ struct server_context {
                 }
 
                 // skip the slot if it does not contains cached tokens
-                if (slot.prompt_tokens.empty()) {
+                if (slot.cache_tokens.empty()) {
                     continue;
                 }
 
-                // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = longest_common_prefix(slot.cache_tokens, slot.prompt_tokens);
+                // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
+                int lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
 
-                // fraction of the common substring length compared to the current slot's prompt length
-                similarity = static_cast<float>(lcp_len) / static_cast<int>(slot.prompt_tokens.size());
+                // fraction of the common subsequence length compared to the current slot's prompt length
+                similarity = static_cast<float>(lcs_len) / static_cast<int>(slot.cache_tokens.size());
 
                 // select the current slot if the criteria match
-                if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
-                    max_lcp_len = lcp_len;
+                if (lcs_len > max_lcs_len && similarity > slot_prompt_similarity) {
+                    max_lcs_len = lcs_len;
                     ret = &slot;
                 }
             }
 
             if (ret != nullptr) {
-                SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
+                SLT_DBG(*ret, "selected slot by lcs similarity, max_lcs_len = %d, similarity = %f\n", max_lcs_len, similarity);
             }
         }
 
@@ -1514,18 +1514,7 @@ struct server_context {
                 {
                     const int id_slot = json_value(task.data, "id_slot", -1);
 
-                    server_slot * slot;
-
-                    if (id_slot != -1) {
-                        slot = get_slot_by_id(id_slot);
-                    } else {
-                        std::string prompt;
-                        if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
-                            prompt = json_value(task.data, "prompt", std::string());
-                        }
-
-                        slot = get_available_slot(prompt);
-                    }
+                    server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
 
                     if (slot == nullptr) {
                         // if no slot is available, we defer this task for processing later
index 58f5a5684ac115fb6300d6a88a954bc69f8ac085..871a17a4f617ab46648a09873e04bbcfccc703e1 100644 (file)
@@ -439,18 +439,60 @@ static std::string gen_chatcmplid() {
 // other common utils
 //
 
-static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
 
     return i;
 }
 
-static size_t longest_common_prefix(const std::string & a, const std::string & b) {
-    size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
 
-    return i;
+    // get the lengths of the input sequences
+    int a_len = a.size();
+    int b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    int max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<int> prev_row(b_len + 1, 0);
+    std::vector<int> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (int i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (int j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequences, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
 }
 
 static bool ends_with(const std::string & str, const std::string & suffix) {