]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : remove hack for extra parallel slot (#10187)
authorGeorgi Gerganov <redacted>
Wed, 6 Nov 2024 11:29:01 +0000 (13:29 +0200)
committerGitHub <redacted>
Wed, 6 Nov 2024 11:29:01 +0000 (13:29 +0200)
ggml-ci

examples/server/server.cpp

index f0b89b22cd22da9a0f72d971ba68948e499457c2..1c7f0fd1dd1c3c39ec7011c496e732e1e3ec3e55 100644 (file)
@@ -378,8 +378,8 @@ struct server_queue {
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task&)> callback_new_task;
-    std::function<void(void)>         callback_update_slots;
+    std::function<void(server_task)> callback_new_task;
+    std::function<void(void)>        callback_update_slots;
 
     // Add a new task to the end of the queue
     int post(server_task task, bool front = false) {
@@ -431,7 +431,7 @@ struct server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(server_task &)> callback) {
+    void on_new_task(std::function<void(server_task)> callback) {
         callback_new_task = std::move(callback);
     }
 
@@ -481,7 +481,7 @@ struct server_queue {
                 lock.unlock();
 
                 QUE_DBG("processing task, id = %d\n", task.id);
-                callback_new_task(task);
+                callback_new_task(std::move(task));
             }
 
             // all tasks in the current loop is processed, slots data is now ready
@@ -644,17 +644,12 @@ struct server_context {
     bool load_model(const common_params & params_) {
         params = params_;
 
-        // reserve one extra sequence (seq_id == 0) for extra features
-        params.n_parallel += 1;
-
         common_init_result llama_init = common_init_from_params(params);
 
         model = llama_init.model;
         ctx   = llama_init.context;
         loras = llama_init.lora_adapters;
 
-        params.n_parallel -= 1; // but be sneaky about it
-
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
             return false;
@@ -1288,16 +1283,16 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
-        res.id       = slot.id_task;
-        res.error    = false;
-        res.stop     = true;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
 
         const int n_embd = llama_n_embd(model);
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 
@@ -1332,12 +1327,12 @@ struct server_context {
 
     void send_rerank(const server_slot & slot, const llama_batch & batch) {
         server_task_result res;
-        res.id       = slot.id_task;
-        res.error    = false;
-        res.stop     = true;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
                 continue;
             }
 
@@ -1510,7 +1505,7 @@ struct server_context {
     // Functions to process the task
     //
 
-    void process_single_task(const server_task & task) {
+    void process_single_task(server_task task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_INFERENCE:
                 {
@@ -1646,7 +1641,7 @@ struct server_context {
                     std::string filename = task.data.at("filename");
                     std::string filepath = task.data.at("filepath");
 
-                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
 
                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1688,7 +1683,7 @@ struct server_context {
 
                     slot->cache_tokens.resize(slot->n_ctx);
                     size_t token_count = 0;
-                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
+                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
                     if (nread == 0) {
                         slot->cache_tokens.resize(0);
                         send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -1731,7 +1726,7 @@ struct server_context {
 
                     // Erase token cache
                     const size_t n_erased = slot->cache_tokens.size();
-                    llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+                    llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
                     slot->cache_tokens.clear();
 
                     server_task_result result;
@@ -1808,8 +1803,8 @@ struct server_context {
 
                 SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past,        -n_discard);
+                llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past,        -n_discard);
 
                 if (slot.params.cache_prompt) {
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -1836,7 +1831,7 @@ struct server_context {
 
             slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
+            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
 
             slot.n_past += 1;
 
@@ -1983,8 +1978,8 @@ struct server_context {
 
                                             const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
 
-                                            llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
-                                            llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1,     kv_shift);
+                                            llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+                                            llama_kv_cache_seq_add(ctx, slot.id, head_c, -1,     kv_shift);
 
                                             for (size_t i = 0; i < n_match; i++) {
                                                 slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
@@ -2033,9 +2028,9 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                         // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                        llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
 
                         // there is no common part left
                         slot.n_past = 0;
@@ -2048,7 +2043,7 @@ struct server_context {
 
                     // add prompt tokens for processing in the current batch
                     while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);