]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : auto-batch preparation (#13845)
authorGeorgi Gerganov <redacted>
Sat, 31 May 2025 09:55:57 +0000 (12:55 +0300)
committerGitHub <redacted>
Sat, 31 May 2025 09:55:57 +0000 (12:55 +0300)
* llama : auto-batch

ggml-ci

* context : simplify if branching

examples/parallel/parallel.cpp
src/llama-context.cpp
src/llama-context.h
src/llama-kv-cache.cpp
tools/server/server.cpp

index 22118faf8c20d73ab7198e6c25eb51b52e423eac..d7b269df0dea22fef787bb27cd8c057904b4c4b3 100644 (file)
@@ -392,7 +392,7 @@ int main(int argc, char ** argv) {
                     return 1;
                 }
 
-                LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+                LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
 
                 n_cache_miss += 1;
 
index 808fe5991088ce7e1e9db6a144ee810fb75621c0..57c7b42269798baf9966e3cfa2c44755cb9a4e92 100644 (file)
@@ -424,28 +424,33 @@ const llama_kv_cache * llama_context::get_kv_self() const {
     return kv_self;
 }
 
-void llama_context::kv_self_update() {
+bool llama_context::kv_self_update() {
     if (!memory) {
-        return;
+        return false;
     }
 
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-    if (kv_self->update(*this)) {
-        // if the KV cache did any computation, we have to reserve a new worst-case graph
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
-            throw std::runtime_error("failed to initialize KV cache");
-        }
+    if (!kv_self->update(*this)) {
+        // no updates have been performed
+        return false;
+    }
 
-        const uint32_t n_seqs   = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+    // if the KV cache did any computation, we have to reserve a new worst-case graph
+    const auto kv_state = kv_self->init_full();
+    if (!kv_state) {
+        throw std::runtime_error("failed to initialize KV cache");
+    }
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-        if (!gf) {
-            LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
-        }
+    const uint32_t n_seqs   = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+    if (!gf) {
+        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
     }
+
+    return true;
 }
 
 enum llama_pooling_type llama_context::pooling_type() const {
@@ -933,24 +938,44 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update();
 
-    auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-    if (!kv_state) {
-        return -2;
-    }
+    llama_memory_state_ptr kv_state;
 
-    switch (kv_state->get_status()) {
-        case LLAMA_MEMORY_STATUS_SUCCESS:
-            {
-            } break;
-        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
-            {
-                // not a fatal error, we can re-try with a different batch
-                return 1;
-            }
-        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
-            {
-                return -2;
-            }
+    bool did_defrag = false;
+
+    while (true) {
+        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!kv_state) {
+            return -2;
+        }
+
+        switch (kv_state->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                } break;
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+                {
+                    if (!did_defrag) {
+                        did_defrag = true;
+
+                        kv_self->defrag_sched(-1.0f);
+                        if (kv_self_update()) {
+                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+
+                            continue;
+                        }
+                    }
+
+                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+                    return 1;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    return -2;
+                }
+        }
+
+        break;
     }
 
     // reserve output buffer
@@ -2646,22 +2671,8 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    int ret = ctx->decode(batch);
-
-    // defrag and try again
-    // TODO: distinguish return code when we are sure that even after defrag there is no space available
-    if (ret == 1) {
-        llama_kv_self_defrag(ctx);
-        ret = ctx->decode(batch);
-
-        if (ret == 1) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
-
-            return ret;
-        }
-    }
-
-    if (ret != 0) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
index 5b79bafa75db755e56920b8b2940e7ef0f1da6ce..3b880286bfd5de755467aea8eef849a3f13df766 100644 (file)
@@ -50,8 +50,9 @@ struct llama_context {
           llama_kv_cache * get_kv_self();
     const llama_kv_cache * get_kv_self() const;
 
+    // return true of the KV cache was updated
     // TODO: remove
-    void kv_self_update();
+    bool kv_self_update();
 
     enum llama_pooling_type pooling_type() const;
 
index 86c4f2816f8097282fd1e220e88a2db5470c538b..4726b700ff926cf5302cd118dd0be1ff3919b362 100644 (file)
@@ -1809,9 +1809,10 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    // TODO: if we fail with split_simple, we should attempt different splitting strategies
+    //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    // TODO: if we fail with split_simple, we should attempt split_equal
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
     std::vector<llama_ubatch> ubatches;
 
index 90981ff9a5ef75835f0d284158bbd8f4e004d9bc..46dbe5cc3951df7a61139919af21dec19ab1905a 100644 (file)
@@ -3431,7 +3431,7 @@ struct server_context {
                 // retry with half the batch size to try to find a free slot in the KV cache
                 n_batch /= 2;
 
-                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
+                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
                 continue; // continue loop of n_batch
             }