]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
server : simplify state machine for slot (#9283)
authorXuan Son Nguyen <redacted>
Fri, 6 Sep 2024 21:21:29 +0000 (23:21 +0200)
committerGitHub <redacted>
Fri, 6 Sep 2024 21:21:29 +0000 (23:21 +0200)
* server : simplify state machine for slot

* add SLOT_STATE_DONE_PROMPT

* pop_deferred_task

* add missing notify_one

* fix passkey test

* metrics : add n_busy_slots_per_decode

* fix test step

* add test

* maybe fix AddressSanitizer?

* fix deque ?

* missing lock

* pop_deferred_task: also notify

* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: Georgi Gerganov <redacted>
examples/server/server.cpp
examples/server/tests/features/parallel.feature
examples/server/tests/features/passkey.feature
examples/server/tests/features/steps/steps.py

index 5f2238d61afeb754c68e2593b5301930967ccf37..cc65c57ab723c9dc0abaddb6506051e38d7dff01 100644 (file)
@@ -50,15 +50,12 @@ enum stop_type {
     STOP_TYPE_PARTIAL,
 };
 
+// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_PROCESSING,
-};
-
-enum slot_command {
-    SLOT_COMMAND_NONE,
-    SLOT_COMMAND_LOAD_PROMPT,
-    SLOT_COMMAND_RELEASE,
+    SLOT_STATE_PROCESSING_PROMPT,
+    SLOT_STATE_DONE_PROMPT,
+    SLOT_STATE_GENERATING,
 };
 
 enum server_state {
@@ -135,7 +132,6 @@ struct server_slot {
     struct slot_params params;
 
     slot_state state = SLOT_STATE_IDLE;
-    slot_command command = SLOT_COMMAND_NONE;
 
     // used to determine the slot that has been used the longest
     int64_t t_last_used = -1;
@@ -194,6 +190,8 @@ struct server_slot {
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
+    std::function<void(int)> callback_on_release;
+
     void reset() {
         n_prompt_tokens    = 0;
         generated_text     = "";
@@ -228,25 +226,28 @@ struct server_slot {
         return n_remaining > 0; // no budget
     }
 
-    bool available() const {
-        return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
-    }
-
     bool is_processing() const {
-        return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
+        return state != SLOT_STATE_IDLE;
     }
 
     void add_token_string(const completion_token_output & token) {
-        if (command == SLOT_COMMAND_RELEASE) {
+        if (!is_processing()) {
             return;
         }
         generated_token_probs.push_back(token);
     }
 
     void release() {
-        if (state == SLOT_STATE_PROCESSING) {
+        if (is_processing()) {
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
-            command = SLOT_COMMAND_RELEASE;
+            state = SLOT_STATE_IDLE;
+            LOG_INFO("slot released", {
+                {"id_slot",   id},
+                {"id_task",   id_task},
+                {"n_past",    n_past},
+                {"truncated", truncated},
+            });
+            callback_on_release(id);
         }
     }
 
@@ -353,6 +354,9 @@ struct server_metrics {
     uint64_t n_tokens_predicted  = 0;
     uint64_t t_tokens_generation = 0;
 
+    uint64_t n_decode_total     = 0;
+    uint64_t n_busy_slots_total = 0;
+
     void init() {
         t_start = ggml_time_us();
     }
@@ -371,6 +375,15 @@ struct server_metrics {
         t_tokens_generation_total  += slot.t_token_generation;
     }
 
+    void on_decoded(const std::vector<server_slot> & slots) {
+        n_decode_total++;
+        for (const auto & slot : slots) {
+            if (slot.is_processing()) {
+                n_busy_slots_total++;
+            }
+        }
+    }
+
     void reset_bucket() {
         n_prompt_tokens_processed = 0;
         t_prompt_processing       = 0;
@@ -432,6 +445,7 @@ struct server_queue {
     void defer(server_task task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         queue_tasks_deferred.push_back(std::move(task));
+        condition_tasks.notify_one();
     }
 
     // Get the next id for creating a new task
@@ -452,14 +466,14 @@ struct server_queue {
         callback_update_slots = std::move(callback);
     }
 
-    // Call when the state of one slot is changed
-    void notify_slot_changed() {
-        // move deferred tasks back to main loop
+    // Call when the state of one slot is changed, it will move one task from deferred to main queue
+    void pop_deferred_task() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-        for (auto & task : queue_tasks_deferred) {
-            queue_tasks.push_back(std::move(task));
+        if (!queue_tasks_deferred.empty()) {
+            queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
+            queue_tasks_deferred.pop_front();
         }
-        queue_tasks_deferred.clear();
+        condition_tasks.notify_one();
     }
 
     // end the start_loop routine
@@ -489,7 +503,7 @@ struct server_queue {
                     break;
                 }
                 server_task task = queue_tasks.front();
-                queue_tasks.erase(queue_tasks.begin());
+                queue_tasks.pop_front();
                 lock.unlock();
                 LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
                 callback_new_task(task);
@@ -717,6 +731,10 @@ struct server_context {
 
             slot.sparams = params.sparams;
 
+            slot.callback_on_release = [this](int) {
+                queue_tasks.pop_deferred_task();
+            };
+
             slot.reset();
 
             slots.push_back(slot);
@@ -798,7 +816,7 @@ struct server_context {
 
             for (server_slot & slot : slots) {
                 // skip the slot if it is not available
-                if (!slot.available()) {
+                if (slot.is_processing()) {
                     continue;
                 }
 
@@ -840,7 +858,7 @@ struct server_context {
             int64_t t_last = ggml_time_us();
             for (server_slot & slot : slots) {
                 // skip the slot if it is not available
-                if (!slot.available()) {
+                if (slot.is_processing()) {
                     continue;
                 }
 
@@ -1078,7 +1096,7 @@ struct server_context {
             }
         }
 
-        slot.command = SLOT_COMMAND_LOAD_PROMPT;
+        slot.state = SLOT_STATE_PROCESSING_PROMPT;
         slot.prompt_tokens.clear();
 
         LOG_INFO("slot is processing task", {
@@ -1622,7 +1640,7 @@ struct server_context {
                         queue_tasks.defer(task);
                         break;
                     }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                         queue_tasks.defer(task);
@@ -1728,6 +1746,9 @@ struct server_context {
                         { "n_tokens_predicted",              metrics.n_tokens_predicted},
                         { "t_tokens_generation",             metrics.t_tokens_generation},
 
+                        { "n_decode_total",                  metrics.n_decode_total},
+                        { "n_busy_slots_total",              metrics.n_busy_slots_total},
+
                         { "kv_cache_tokens_count",           llama_get_kv_cache_token_count(ctx)},
                         { "kv_cache_used_cells",             llama_get_kv_cache_used_cells(ctx)},
 
@@ -1747,7 +1768,7 @@ struct server_context {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                         queue_tasks.defer(task);
@@ -1788,7 +1809,7 @@ struct server_context {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                         queue_tasks.defer(task);
@@ -1836,7 +1857,7 @@ struct server_context {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                         queue_tasks.defer(task);
@@ -1876,33 +1897,12 @@ struct server_context {
             system_prompt_update();
         }
 
-        // release slots
-        for (auto & slot : slots) {
-            if (slot.command == SLOT_COMMAND_RELEASE) {
-                slot.state       = SLOT_STATE_IDLE;
-                slot.command     = SLOT_COMMAND_NONE;
-                slot.t_last_used = ggml_time_us();
-
-                LOG_INFO("slot released", {
-                    {"id_slot",         slot.id},
-                    {"id_task",         slot.id_task},
-                    {"n_ctx",           n_ctx},
-                    {"n_past",          slot.n_past},
-                    {"n_system_tokens", system_tokens.size()},
-                    {"n_cache_tokens",  slot.cache_tokens.size()},
-                    {"truncated",       slot.truncated}
-                });
-
-                queue_tasks.notify_slot_changed();
-            }
-        }
-
         // check if all slots are idle
         {
             bool all_idle = true;
 
             for (auto & slot : slots) {
-                if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
+                if (slot.is_processing()) {
                     all_idle = false;
                     break;
                 }
@@ -1973,7 +1973,7 @@ struct server_context {
 
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
-            if (slot.state == SLOT_STATE_IDLE) {
+            if (slot.state != SLOT_STATE_GENERATING) {
                 continue;
             }
 
@@ -2015,7 +2015,7 @@ struct server_context {
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
-                if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
                     auto & prompt_tokens = slot.prompt_tokens;
 
                     // we haven't tokenized the prompt yet - do it now:
@@ -2083,8 +2083,6 @@ struct server_context {
                                 {"id_task", slot.id_task}
                             });
 
-                            slot.state = SLOT_STATE_PROCESSING;
-                            slot.command = SLOT_COMMAND_NONE;
                             slot.release();
                             slot.print_timings();
                             send_final_response(slot);
@@ -2094,8 +2092,6 @@ struct server_context {
                         if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                             // this prompt is too large to process - discard it
                             if (slot.n_prompt_tokens > n_ubatch) {
-                                slot.state = SLOT_STATE_PROCESSING;
-                                slot.command = SLOT_COMMAND_NONE;
                                 slot.release();
                                 send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                                 continue;
@@ -2253,10 +2249,9 @@ struct server_context {
                         {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
                     });
 
-                    // entire prompt has been processed - start decoding new tokens
+                    // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
-                        slot.state   = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
+                        slot.state = SLOT_STATE_DONE_PROMPT;
 
                         GGML_ASSERT(batch.n_tokens > 0);
 
@@ -2338,18 +2333,17 @@ struct server_context {
             };
 
             const int ret = llama_decode(ctx, batch_view);
+            metrics.on_decoded(slots);
 
             if (ret != 0) {
                 if (n_batch == 1 || ret < 0) {
                     // if you get here, it means the KV cache is full - try increasing it via the context size
                     LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
-                        {"i",   i},
-                        {"n_batch",  ret},
-                        {"ret",   ret},
+                        {"i",       i},
+                        {"n_batch", n_batch},
+                        {"ret",     ret},
                     });
                     for (auto & slot : slots) {
-                        slot.state = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
                         slot.release();
                         send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
                     }
@@ -2361,24 +2355,31 @@ struct server_context {
                 i -= n_batch;
 
                 LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
-                    {"i",   i},
-                    {"n_batch",  n_batch},
-                    {"ret",   ret},
+                    {"i",       i},
+                    {"n_batch", n_batch},
+                    {"ret",     ret},
                 });
 
                 continue; // continue loop of n_batch
             }
 
             for (auto & slot : slots) {
-                if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
+                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                     continue; // continue loop of slots
                 }
 
-                // prompt evaluated for embedding
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
-                    send_embedding(slot, batch_view);
-                    slot.release();
-                    slot.i_batch = -1;
+                if (slot.state == SLOT_STATE_DONE_PROMPT) {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                        // prompt evaluated for embedding
+                        send_embedding(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    } else {
+                        // prompt evaluated for next-token prediction
+                        slot.state = SLOT_STATE_GENERATING;
+                    }
+                } else if (slot.state != SLOT_STATE_GENERATING) {
                     continue; // continue loop of slots
                 }
 
@@ -2425,6 +2426,7 @@ struct server_context {
                 }
 
                 if (!process_token(result, slot)) {
+                    // release slot because of stop condition
                     slot.release();
                     slot.print_timings();
                     send_final_response(slot);
@@ -2705,7 +2707,7 @@ int main(int argc, char ** argv) {
         task.type = SERVER_TASK_TYPE_METRICS;
 
         ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+        ctx_server.queue_tasks.post(task, true); // high-priority task
 
         // get the result
         server_task_result result = ctx_server.queue_results.recv(task.id);
@@ -2737,7 +2739,7 @@ int main(int argc, char ** argv) {
         task.data.push_back({{"reset_bucket", true}});
 
         ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+        ctx_server.queue_tasks.post(task, true); // high-priority task
 
         // get the result
         server_task_result result = ctx_server.queue_results.recv(task.id);
@@ -2751,6 +2753,9 @@ int main(int argc, char ** argv) {
         const uint64_t n_tokens_predicted  = data.at("n_tokens_predicted");
         const uint64_t t_tokens_generation = data.at("t_tokens_generation");
 
+        const uint64_t n_decode_total     = data.at("n_decode_total");
+        const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
+
         const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
 
         // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
@@ -2771,6 +2776,14 @@ int main(int argc, char ** argv) {
                     {"name",  "tokens_predicted_seconds_total"},
                     {"help",  "Predict process time"},
                     {"value",  (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
+            }, {
+                    {"name",  "n_decode_total"},
+                    {"help",  "Total number of llama_decode() calls"},
+                    {"value",  n_decode_total}
+            }, {
+                    {"name",  "n_busy_slots_per_decode"},
+                    {"help",  "Average number of busy slots per llama_decode() call"},
+                    {"value",  (float) n_busy_slots_total / (float) n_decode_total}
             }}},
             {"gauge", {{
                     {"name",  "prompt_tokens_seconds"},
@@ -2837,7 +2850,7 @@ int main(int argc, char ** argv) {
         task.data = {
             { "id_slot", id_slot },
             { "filename", filename },
-            { "filepath", filepath }
+            { "filepath", filepath },
         };
 
         const int id_task = ctx_server.queue_tasks.post(task);
@@ -2867,7 +2880,7 @@ int main(int argc, char ** argv) {
         task.data = {
             { "id_slot", id_slot },
             { "filename", filename },
-            { "filepath", filepath }
+            { "filepath", filepath },
         };
 
         const int id_task = ctx_server.queue_tasks.post(task);
@@ -2945,7 +2958,7 @@ int main(int argc, char ** argv) {
             { "system_prompt",               ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params.n_parallel },
-            { "chat_template",               curr_tmpl.c_str() }
+            { "chat_template",               curr_tmpl.c_str() },
         };
 
         res_ok(res, data);
@@ -3056,13 +3069,13 @@ int main(int argc, char ** argv) {
         json models = {
             {"object", "list"},
             {"data", {
-                 {
-                     {"id",       params.model_alias},
-                     {"object",   "model"},
-                     {"created",  std::time(0)},
-                     {"owned_by", "llamacpp"},
-                     {"meta",     ctx_server.model_meta()}
-                 },
+                {
+                    {"id",       params.model_alias},
+                    {"object",   "model"},
+                    {"created",  std::time(0)},
+                    {"owned_by", "llamacpp"},
+                    {"meta",     ctx_server.model_meta()}
+                },
              }}
         };
 
index 6cd306a2bcf7c9b5a7eb56550628c8a2e82e5f2c..423d0f1d42f550189255bfdca52a7dd9968a7ff1 100644 (file)
@@ -77,6 +77,35 @@ Feature: Parallel
       | disabled  | 128       |
       | enabled   | 64        |
 
+  Scenario Outline: Multi users with number of prompts exceeding number of slots
+    Given a system prompt You are a writer.
+    And   a model tinyllama-2
+    Given a prompt:
+      """
+      Write a very long book.
+      """
+    And a prompt:
+      """
+      Write another a poem.
+      """
+    And a prompt:
+      """
+      What is LLM?
+      """
+    And a prompt:
+      """
+      The sky is blue and I love it.
+      """
+    And <n_predict> max tokens to predict
+    And streaming is <streaming>
+    Given concurrent OAI completions requests
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted with <n_predict> tokens
+    Examples:
+      | streaming | n_predict |
+      | disabled  | 128       |
+      | enabled   | 64        |
 
   Scenario:  Multi users with total number of tokens to predict exceeds the KV Cache size #3969
     Given a prompt:
index 6a5a84e6a194177ea4b5ad17bfe803cb0ab02f39..ff0a82cc4658117bc8d8db76b4d185ebcf6cc83a 100644 (file)
@@ -15,6 +15,7 @@ Feature: Passkey / Self-extend with context shift
     And   <n_junk> as number of junk
     And   <n_predicted> server max tokens to predict
     And   42 as seed
+    And   0.0 temperature
     And   <n_ctx> KV cache size
     And   1 slots
     And   <n_ga> group attention factor to extend context size through self-extend
@@ -22,7 +23,8 @@ Feature: Passkey / Self-extend with context shift
     # Can be override with N_GPU_LAYERS
     And   <ngl> GPU offloaded layers
     Then  the server is starting
-    Then  the server is healthy
+    # Higher timeout because the model may need to be downloaded from the internet
+    Then  the server is healthy with timeout 120 seconds
     Given available models
     Then  model 0 is trained on <n_ctx_train> tokens context
     Given a prefix prompt:
index 18daad4760e701a20acb409f87756e1c4bc073e2..65b71a8e85db12430bdfad860bca4a401073c5b0 100644 (file)
@@ -202,17 +202,15 @@ def step_start_server(context):
             time.sleep(0.1)
 
 
-@step("the server is {expecting_status}")
-@async_run_until_complete
-async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
+async def wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int):
     match expecting_status:
         case 'healthy':
             await wait_for_slots_status(context, context.base_url, 200,
-                                        timeout=30)
+                                        timeout=timeout)
 
         case 'ready' | 'idle':
             await wait_for_slots_status(context, context.base_url, 200,
-                                        timeout=30,
+                                        timeout=timeout,
                                         params={'fail_on_no_slot': 1},
                                         slots_idle=context.n_slots,
                                         slots_processing=0)
@@ -225,6 +223,18 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status: Lite
             assert False, "unknown status"
 
 
+@step("the server is {expecting_status} with timeout {timeout:d} seconds")
+@async_run_until_complete
+async def step_wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int):
+    await wait_for_server_status_with_timeout(context, expecting_status, timeout)
+
+
+@step("the server is {expecting_status}")
+@async_run_until_complete
+async def step_wait_for_server_status(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
+    await wait_for_server_status_with_timeout(context, expecting_status, 30)
+
+
 @step('all slots are {expected_slot_status_string}')
 @async_run_until_complete
 async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):