]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama: consistent ctx <-> buf order for KV cache (#16746)
authorJohannes Gäßler <redacted>
Tue, 28 Oct 2025 10:23:54 +0000 (11:23 +0100)
committerGitHub <redacted>
Tue, 28 Oct 2025 10:23:54 +0000 (11:23 +0100)
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-memory-recurrent.cpp
src/llama-memory-recurrent.h
src/llama-model.cpp

index 736693e174527d76aeaabd7f27855d72e033ae66..add74391f0c47bb7ed6cd4179526f5ecb956db80 100644 (file)
@@ -8,6 +8,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cmath>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache(
 
     const uint32_t n_layer_kv = hparams.n_layer_kv();
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max);
@@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for kv cache");
         }
@@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache(
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
 
         ggml_backend_buffer_clear(buf, 0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) {
     }
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -1298,7 +1302,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
index 85f0663d8c1d4247b9712100372bf8701cf22df4..150e282596255d5e8121bf4b781523d9b59dbf79 100644 (file)
@@ -217,8 +217,8 @@ private:
     // this is the SWA type of the cache - not to be confused with the model SWA type
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
index d67f5a5f47b87c2440a2a20c281ac71692498e81..276e1697d466c6dd738c3335265735db7e0a8331 100644 (file)
@@ -7,6 +7,7 @@
 
 #include <algorithm>
 #include <cassert>
+#include <cstring>
 #include <limits>
 #include <map>
 #include <stdexcept>
@@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent(
     cells.clear();
     cells.resize(mem_size);
 
+    // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
+    struct ggml_backend_buft_comparator {
+        bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
+        }
+    };
+    std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;
+
     // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
@@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent(
                 return nullptr;
             }
 
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
+            ctx_map.emplace(buft, ctx);
 
             return ctx;
         }
 
-        return it->second;
+        return it->second.get();
     };
 
     r_l.resize(n_layer);
@@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent(
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+    for (auto & [buft, ctx] : ctx_map) {
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft);
         if (!buf) {
             throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
+        ctxs_bufs.emplace_back(std::move(ctx), buf);
     }
 
     {
@@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) {
     used = 0;
 
     if (data) {
-        for (auto & buf : bufs) {
+        for (auto & [_, buf] : ctxs_bufs) {
             ggml_backend_buffer_clear(buf.get(), 0);
         }
     }
@@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 
 std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
     std::map<ggml_backend_buffer_type_t, size_t> ret;
-    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
-        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    for (const auto & [_, buf] : ctxs_bufs) {
+        ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get());
     }
     return ret;
 }
@@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const {
 
 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
-    for (const auto & buf : bufs) {
+    for (const auto & [_, buf] : ctxs_bufs) {
         size += ggml_backend_buffer_get_size(buf.get());
     }
 
index 077c6e3ce938da5f6dfd8d7d63f842dec2989fca..47f01d7391248a7b00d095d7d94aca785394ee1f 100644 (file)
@@ -109,8 +109,8 @@ private:
 
     const uint32_t n_seq_max = 1;
 
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
+    // ggml contexts for the KV cache along with the allocated backend buffers:
+    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
 
     size_t total_size() const;
 
index 05e467180089e800226556c957605369d34e519c..bb83a04e9605531d2dd0ffa36ce942d7d1a60fbf 100644 (file)
@@ -2231,7 +2231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     // define a comparator for the buft -> ctx map to ensure that the order is well-defined:
     struct ggml_backend_buft_comparator {
         bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const {
-            return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs);
+            return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0;
         }
     };
     std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map;