]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : propagate the results of `graph_compute` (#9525)
authorMichael Podvitskiy <redacted>
Wed, 13 Nov 2024 18:00:35 +0000 (20:00 +0200)
committerGitHub <redacted>
Wed, 13 Nov 2024 18:00:35 +0000 (20:00 +0200)
* llama: propagating the results of `graph_compute` to the user interface

* llama: reverting kv_cache in case of failed compute

* llama: `llama_kv_cache_state` was removed, only the result of `llama_graph_compute` is returned

* llama: restore a kv_cache in case of failed computation

* llama: correct reverting of the entire batch.
also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models

* llama: updated comments

* llama : add comments about KV cache state after error

---------

Co-authored-by: Georgi Gerganov <redacted>
include/llama.h
src/llama.cpp

index ccb48f73cef5cb7ceed6cac15b23419751e9e963..5e742642eec8de2ebdcbb3e756290829ff2623b5 100644 (file)
@@ -797,7 +797,7 @@ extern "C" {
     // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
     // Stores the encoder output internally for later use by the decoder cross-attention layers.
     //   0 - success
-    // < 0 - error
+    // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
@@ -805,7 +805,7 @@ extern "C" {
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    // < 0 - error
+    // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);
index 4d89c522257c5ff8457e1d76eadd3964697cc2e4..97eee26a577bccd357414c5c8ae6713f87c4c174 100644 (file)
@@ -3502,11 +3502,24 @@ static bool llama_kv_cache_init(
     return true;
 }
 
+// a structure holds information about the slot found in llama_kv_cache_find_slot
+struct llama_kv_cache_slot_info {
+    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
+    bool found = false;                       // the slot was found
+
+    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
+    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
+
+    operator bool() const { return found; }
+};
+static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
+
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// returns a structure holding information about the slot found
 // Note: On success, it's important that cache.head points
 // to the first cell of the slot.
-static bool llama_kv_cache_find_slot(
+static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
        const struct llama_ubatch & batch) {
     const uint32_t n_tokens = batch.n_tokens;
@@ -3534,7 +3547,7 @@ static bool llama_kv_cache_find_slot(
                     // too big seq_id
                     // TODO: would it be possible to resize the cache instead?
                     LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                    return false;
+                    return llama_kv_cache_slot_info_failed;
                 }
                 if (j > 0) {
                     llama_kv_cell & seq = cache.cells[seq_id];
@@ -3669,15 +3682,17 @@ static bool llama_kv_cache_find_slot(
         // allow getting the range of used cells, from head to head + n
         cache.head = min;
         cache.n    = max - min + 1;
+        cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
+            [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return cache.n >= n_seqs;
+        return llama_kv_cache_slot_info(cache.n >= n_seqs);
     }
     // otherwise, one cell per token.
 
     if (n_tokens > cache.size) {
         LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
-        return false;
+        return llama_kv_cache_slot_info_failed;
     }
 
     uint32_t n_tested = 0;
@@ -3705,7 +3720,7 @@ static bool llama_kv_cache_find_slot(
 
         if (n_tested >= cache.size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return false;
+            return llama_kv_cache_slot_info_failed;
         }
     }
 
@@ -3722,7 +3737,7 @@ static bool llama_kv_cache_find_slot(
 
     cache.used += n_tokens;
 
-    return true;
+    return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
 }
 
 // find how many cells are currently in use
@@ -3998,6 +4013,53 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+// saves the kv_cache state for future recovery.
+// used to rollback llama_kv_cache_find_slot changes.
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    // for non-recurrent models only
+    // list of slots to restore
+    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
+
+    bool do_restore = false;
+
+    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+        old_state.head  = cache.head;
+        old_state.n     = cache.n;
+    }
+
+    // saves a slot information for future restoration
+    void save(const struct llama_kv_cache_slot_info & slot) {
+        if (slot) {
+            do_restore = true;
+            if (slot.boundaries.first != slot.boundaries.second) {
+                slot_boundaries.push_back(slot.boundaries);
+            }
+        }
+    }
+
+    // must be explicitly called to restore the kv_cache state
+    // and rollback changes from all llama_kv_cache_find_slot calls
+    void restore(struct llama_kv_cache & cache) {
+        if (do_restore) {
+            cache.head  = old_state.head;
+            cache.n     = old_state.n;
+
+            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
+                llama_kv_cache_seq_rm(cache, -1, -1, -1);
+            } else {
+                for (auto & slot : slot_boundaries) {
+                    llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
+                }
+            }
+        }
+    }
+};
+
 //
 // model loading and saving
 //
@@ -17181,7 +17243,8 @@ static void llama_output_reorder(struct llama_context * ctx) {
     }
 }
 
-static void llama_graph_compute(
+// returns the result of ggml_backend_sched_graph_compute_async execution
+static enum ggml_status llama_graph_compute(
           llama_context & lctx,
             ggml_cgraph * gf,
                     int   n_threads,
@@ -17196,15 +17259,20 @@ static void llama_graph_compute(
         set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
 
-    auto err = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
-    if (err != GGML_STATUS_SUCCESS) {
-        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
 // decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
 //
 //   - lctx:      llama context
 //   - batch:     batch to evaluate
@@ -17254,6 +17322,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17338,9 +17407,11 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+            if (!slot) {
                 return 1;
             }
+            kv_slot_restorer.save(slot);
 
             if (!kv_self.recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17387,7 +17458,19 @@ static int llama_decode_internal(
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            kv_slot_restorer.restore(kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
 
         // update the kv ring buffer
         {
@@ -17624,7 +17707,18 @@ static int llama_encode_internal(
 
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
     // extract embeddings
     if (embd) {