]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : refactor kv cache guard (#12695)
authorGeorgi Gerganov <redacted>
Wed, 2 Apr 2025 11:32:59 +0000 (14:32 +0300)
committerGitHub <redacted>
Wed, 2 Apr 2025 11:32:59 +0000 (14:32 +0300)
* llama : refactor kv cache guard

ggml-ci

* cont : fix comment [no ci]

* llama : fix kv_cache restore logic

ggml-ci

* context : simplify kv cache updates

ggml-ci

* cont : better name [no ci]

* llama : fix llama_decode return code when could not find KV slot

ggml-ci

* context : change log err -> warn [no ci]

* kv-cache : add comment + warning

examples/parallel/parallel.cpp
src/llama-context.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h

index e0e6da631dad367d691e83cb69e245160fad8f7d..80698518e310286ce15ce8c7338ebc857d5149da 100644 (file)
@@ -106,6 +106,8 @@ int main(int argc, char ** argv) {
 
     common_params params;
 
+    params.n_predict = 128;
+
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
         return 1;
     }
index 3479a8cca3d6408e9a354b726806dd4854c5fb0e..7d067afbe73995425fd47a636e5d250bbfa15153 100644 (file)
@@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    // TODO: remove this stuff
-    class batch_guard {
-    public:
-        batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
-        }
-
-        ~batch_guard() {
-            if (!is_done) {
-                kv_slot_restorer.restore();
-            }
-        }
-
-        void done() {
-            is_done = true;
-        }
-
-        void save(const llama_kv_cache_slot_info & slot_info) {
-            kv_slot_restorer.save(slot_info);
-        }
-
-    private:
-        bool is_done = false;
-
-        llama_kv_slot_restorer kv_slot_restorer;
-    };
-
-    batch_guard bg(*kv_self);
+    llama_kv_cache_guard kv_guard(kv_self.get());
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -2;
     };
 
+    // handle any pending defrags/shifts
+    kv_self_update();
+
     int64_t n_outputs_prev = 0;
 
     while (sbatch.n_tokens > 0) {
@@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         {
-            kv_self_update();
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
-                kv_self->head = 0;
+                return 1;
             }
 
-            const auto slot_info = kv_self->find_slot(ubatch);
-            if (!slot_info) {
-                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-                return -3;
-            }
-
-            bg.save(slot_info);
-
             if (!kv_self->recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
@@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
             }
         }
 
-        // update the kv ring buffer
-        {
-            kv_self->head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self->head >= kv_self->size) {
-                kv_self->head = 0;
-            }
-        }
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
 
     // finalize the batch processing
-    bg.done();
+    kv_guard.commit();
 
     // set output mappings
     {
index 14c8933b4d6c4f65d0ad0818b2443ad1c676e249..7ba546c10ff74f8035a7e9398720c3bb77574fd7 100644 (file)
@@ -11,8 +11,6 @@
 #include <map>
 #include <stdexcept>
 
-static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
-
 llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
@@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 return false;
             }
         }
+
+        return true;
     }
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
     }
 }
 
+void llama_kv_cache_unified::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        seq_rm(-1, -1, -1);
+        return;
+    }
+
+    uint32_t new_head = size;
+
+    for (auto & range : pending.ranges) {
+        for (uint32_t i = range.c0; i < range.c1; ++i) {
+            cells[i].seq_id.clear();
+
+            // keep count of the number of used cells
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+        }
+
+        new_head = std::min(new_head, range.c0);
+    }
+
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_unified::commit() {
+    if (pending.ranges.empty()) {
+        LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
+                __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
+        return;
+    }
+
+    pending.ranges.clear();
+}
+
 bool llama_kv_cache_unified::get_can_shift() const {
     return can_shift;
 }
 
-llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
+bool llama_kv_cache_unified::find_slot(
        const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
+    }
+
     if (recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
                     // too big seq_id
                     // TODO: would it be possible to resize the cache instead?
                     LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                    return llama_kv_cache_slot_info_failed;
+                    return false;
                 }
                 if (j > 0) {
                     llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
             [](const llama_kv_cell& cell){ return !cell.is_empty(); });
 
         // sanity check
-        return llama_kv_cache_slot_info(n >= n_seqs);
+        return n >= n_seqs;
     }
 
     // otherwise, one cell per token.
 
     if (n_tokens > size) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
-        return llama_kv_cache_slot_info_failed;
+        return false;
     }
 
     uint32_t n_tested = 0;
@@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
         if (n_tested >= size) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return llama_kv_cache_slot_info_failed;
+            return false;
         }
     }
 
@@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
 
     used += n_tokens;
 
-    return llama_kv_cache_slot_info(head, head + n_tokens);
+    pending.ranges.push_back({head, head + n_tokens});
+
+    return true;
 }
 
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
+        commit();
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
index 0a7ff8a4ea3e685589215932eec0534aa4959f19..ff0ba3540d6e2e62b3d1fb5b08e7e6f72ce1b21c 100644 (file)
@@ -17,6 +17,9 @@ struct llama_ubatch;
 struct llama_kv_cache : public llama_memory_i {
     using llama_memory_i::llama_memory_i;
 
+    virtual void restore() = 0; // call if batch processing fails - restores the cache state
+    virtual void commit() = 0;  // call after successful batch processing - clears any pending state
+
     virtual int32_t  get_n_tokens()   const = 0;
     virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
 
@@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i {
     bool get_can_edit() const override { return get_can_shift(); }
 };
 
+struct llama_kv_cache_guard {
+    llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
+
+    ~llama_kv_cache_guard() {
+        kv->restore();
+    }
+
+    void commit() {
+        kv->commit();
+    }
+
+private:
+    llama_kv_cache * kv;
+};
+
 struct llama_kv_cell {
     llama_pos pos   = -1;
-    llama_pos delta = 0;
+    llama_pos delta =  0;
     int32_t   src   = -1; // used by recurrent state models to copy states
     int32_t   tail  = -1;
 
@@ -46,17 +64,6 @@ struct llama_kv_cell {
     }
 };
 
-// a structure holds information about the slot found in llama_kv_cache_find_slot
-struct llama_kv_cache_slot_info {
-    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
-    bool found = false;                       // the slot was found
-
-    explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
-    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
-
-    operator bool() const { return found; }
-};
-
 // ring-buffer of cached KV data
 // TODO: pimpl
 // TODO: add notion of max sequences
@@ -93,6 +100,9 @@ public:
     void clear() override;
     void defrag() override;
 
+    virtual void restore() override;
+    virtual void commit() override;
+
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id) override;
@@ -105,10 +115,9 @@ public:
 
     // find an empty slot of size "n_tokens" in the cache
     // updates the cache head
-    // returns a structure holding information about the slot found
     // Note: On success, it's important that cache.head points
     // to the first cell of the slot.
-    llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
+    bool find_slot(const llama_ubatch & batch);
 
     // TODO: maybe not needed
     uint32_t get_padding(const llama_cparams & cparams) const;
@@ -128,7 +137,19 @@ public:
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
-    // state save/load
+    // commit/restore cache
+
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    // pending cell updates that are not yet committed
+    struct {
+        std::vector<slot_range> ranges;
+    } pending;
+
+    // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
     void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
@@ -183,59 +204,6 @@ private:
 //    using llama_kv_cache_unified::llama_kv_cache_unified;
 //};
 
-//
-// kv cache restore
-//
-
-// saves the kv_cache state for future recovery.
-// used to rollback llama_kv_cache_find_slot changes.
-struct llama_kv_slot_restorer {
-    struct llama_kv_cache_state {
-        uint32_t head = 0;
-        uint32_t n    = 0;
-    } old_state;
-
-    // for non-recurrent models only
-    // list of slots to restore
-    std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
-
-    bool do_restore = false;
-
-    llama_kv_cache_unified & cache;
-
-    explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
-        old_state.head = cache.head;
-        old_state.n    = cache.n;
-    }
-
-    // saves a slot information for future restoration
-    void save(const llama_kv_cache_slot_info & slot) {
-        if (slot) {
-            do_restore = true;
-            if (slot.boundaries.first != slot.boundaries.second) {
-                slot_boundaries.push_back(slot.boundaries);
-            }
-        }
-    }
-
-    // must be explicitly called to restore the kv_cache state
-    // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore() {
-        if (do_restore) {
-            cache.head = old_state.head;
-            cache.n    = old_state.n;
-
-            if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
-                cache.seq_rm(-1, -1, -1);
-            } else {
-                for (auto & slot : slot_boundaries) {
-                    cache.seq_rm(-1, slot.first, slot.second);
-                }
-            }
-        }
-    }
-};
-
 // TODO: maybe become part of the public llama_kv_cache in the future
 int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);