]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : simplify + fix warning for recurrent models (#12756)
authorGeorgi Gerganov <redacted>
Fri, 4 Apr 2025 18:48:10 +0000 (21:48 +0300)
committerGitHub <redacted>
Fri, 4 Apr 2025 18:48:10 +0000 (21:48 +0300)
ggml-ci

src/llama-context.cpp
src/llama-kv-cache.cpp
src/llama-kv-cache.h
src/llama-memory.h

index 3927079432d941a4bf7c601bb8bc396c09d7688a..4735e98ea040ffe3abeec306fc14625059304880 100644 (file)
@@ -2474,7 +2474,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
 }
 
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    return llama_kv_cache_n_tokens(ctx->get_kv_self());
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_n_tokens();
 }
 
 // deprecated
@@ -2483,7 +2488,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
 }
 
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    return llama_kv_cache_used_cells(ctx->get_kv_self());
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_used_cells();
 }
 
 // deprecated
@@ -2492,7 +2502,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
-    llama_kv_cache_clear(ctx->get_kv_self());
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->clear();
 }
 
 // deprecated
@@ -2509,7 +2524,12 @@ bool llama_kv_self_seq_rm(
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1) {
-    return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return true;
+    }
+
+    return kv->seq_rm(seq_id, p0, p1);
 }
 
 // deprecated
@@ -2528,7 +2548,12 @@ void llama_kv_self_seq_cp(
          llama_seq_id   seq_id_dst,
             llama_pos   p0,
             llama_pos   p1) {
-    return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
 // deprecated
@@ -2539,7 +2564,12 @@ void llama_kv_cache_seq_keep(
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_keep(seq_id);
 }
 
 // deprecated
@@ -2558,7 +2588,12 @@ void llama_kv_self_seq_add(
             llama_pos   p0,
             llama_pos   p1,
             llama_pos   delta) {
-    return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_add(seq_id, p0, p1, delta);
 }
 
 // deprecated
@@ -2577,7 +2612,12 @@ void llama_kv_self_seq_div(
             llama_pos   p0,
             llama_pos   p1,
                   int   d) {
-    return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->seq_div(seq_id, p0, p1, d);
 }
 
 // deprecated
@@ -2586,7 +2626,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->seq_pos_max(seq_id);
 }
 
 // deprecated
@@ -2595,7 +2640,12 @@ void llama_kv_cache_defrag(llama_context * ctx) {
 }
 
 void llama_kv_self_defrag(llama_context * ctx) {
-    llama_kv_cache_defrag(ctx->get_kv_self());
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    return kv->defrag();
 }
 
 // deprecated
@@ -2604,7 +2654,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-    return llama_kv_cache_can_shift(ctx->get_kv_self());
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return false;
+    }
+
+    return kv->get_can_shift();
 }
 
 // deprecated
index 7ba546c10ff74f8035a7e9398720c3bb77574fd7..dbf5f1187d9e557c0bc7640b159e4326a6035191 100644 (file)
@@ -131,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
     return result;
 }
 
-uint32_t llama_kv_cache_unified::get_used_cells() const {
+int32_t llama_kv_cache_unified::get_used_cells() const {
     return used;
 }
 
@@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
     }
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
+llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos result = 0;
 
     for (uint32_t i = 0; i < size; ++i) {
@@ -481,6 +481,11 @@ void llama_kv_cache_unified::restore() {
 }
 
 void llama_kv_cache_unified::commit() {
+    // TODO: tmp - move to llama_kv_cache_recurrent
+    if (recurrent) {
+        return;
+    }
+
     if (pending.ranges.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
                 __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@@ -1273,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     return true;
 }
 
-//
-// interface implementation
-//
-
-int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->get_n_tokens();
-}
-
-int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->get_used_cells();
-}
-
-void llama_kv_cache_clear(llama_kv_cache * kv) {
-    if (!kv) {
-        return;
-    }
-
-    kv->clear();
-}
-
-bool llama_kv_cache_seq_rm(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1) {
-    if (!kv) {
-        return true;
-    }
-
-    return kv->seq_rm(seq_id, p0, p1);
-}
-
-void llama_kv_cache_seq_cp(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id_src,
-          llama_seq_id   seq_id_dst,
-             llama_pos   p0,
-             llama_pos   p1) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_keep(seq_id);
-}
-
-void llama_kv_cache_seq_add(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-             llama_pos   delta) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_add(seq_id, p0, p1, delta);
-}
-
-void llama_kv_cache_seq_div(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-                   int   d) {
-    if (!kv) {
-        return;
-    }
-
-    kv->seq_div(seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
-    if (!kv) {
-        return 0;
-    }
-
-    return kv->seq_pos_max(seq_id);
-}
-
-void llama_kv_cache_defrag(llama_kv_cache * kv) {
-    if (!kv) {
-        return;
-    }
-
-    kv->defrag();
-}
-
-bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
-    if (!kv) {
-        return false;
-    }
-
-    return kv->get_can_shift();
-}
-
 //
 // kv cache view
 //
@@ -1393,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
         /*.n_cells            = */ 0,
         /*.n_seq_max          = */ n_seq_max,
         /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_kv_cache_used_cells(&kv),
+        /*.used_cells         = */ kv.get_used_cells(),
         /*.max_contiguous     = */ 0,
         /*.max_contiguous_idx = */ -1,
         /*.cells              = */ nullptr,
index ff0ba3540d6e2e62b3d1fb5b08e7e6f72ce1b21c..56c74035ae1b9d8d1a12c7ef77cec66938c61861 100644 (file)
@@ -20,8 +20,8 @@ struct llama_kv_cache : public llama_memory_i {
     virtual void restore() = 0; // call if batch processing fails - restores the cache state
     virtual void commit() = 0;  // call after successful batch processing - clears any pending state
 
-    virtual int32_t  get_n_tokens()   const = 0;
-    virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
+    virtual int32_t get_n_tokens()   const = 0;
+    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
 
     virtual bool get_can_shift() const = 0;
 
@@ -89,8 +89,8 @@ public:
                      uint32_t   kv_size,
                          bool   offload);
 
-    int32_t  get_n_tokens()   const override;
-    uint32_t get_used_cells() const override;
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
 
     size_t total_size() const;
 
@@ -109,7 +109,7 @@ public:
     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
-    llama_pos seq_pos_max(llama_seq_id seq_id) override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     bool get_can_shift() const override;
 
@@ -204,48 +204,6 @@ private:
 //    using llama_kv_cache_unified::llama_kv_cache_unified;
 //};
 
-// TODO: maybe become part of the public llama_kv_cache in the future
-int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
-
-int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
-
-void llama_kv_cache_clear(llama_kv_cache * kv);
-
-bool llama_kv_cache_seq_rm(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1);
-
-void llama_kv_cache_seq_cp(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id_src,
-          llama_seq_id   seq_id_dst,
-             llama_pos   p0,
-             llama_pos   p1);
-
-void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
-
-void llama_kv_cache_seq_add(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-             llama_pos   delta);
-
-void llama_kv_cache_seq_div(
-        llama_kv_cache * kv,
-          llama_seq_id   seq_id,
-             llama_pos   p0,
-             llama_pos   p1,
-                   int   d);
-
-llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
-
-void llama_kv_cache_defrag(llama_kv_cache * kv);
-
-bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
-
 //
 // kv cache view
 //
index 69e6e34ca4516f5fc274eb42df903b074060c548..dfa8c4e90fc2a90c61c9da74bc84a9d7eafb3fc5 100644 (file)
@@ -15,7 +15,7 @@ public:
     virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
 
-    virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
     virtual bool get_can_edit() const = 0;
 };