]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory : migrate from llama_kv_cache to more generic llama_memory (#14006)
authorGeorgi Gerganov <redacted>
Thu, 5 Jun 2025 12:29:22 +0000 (15:29 +0300)
committerGitHub <redacted>
Thu, 5 Jun 2025 12:29:22 +0000 (15:29 +0300)
* memory : merge llama_kv_cache into llama_memory + new `llama_memory` API

ggml-ci

* context : fix casts

ggml-ci

include/llama.h
src/CMakeLists.txt
src/llama-context.cpp
src/llama-context.h
src/llama-graph.h
src/llama-kv-cache-recurrent.h
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.h
src/llama-kv-cache.cpp [deleted file]
src/llama-kv-cache.h [deleted file]
src/llama-memory.h

index da0f652cfd63a409014cae7bff57956611f9e817..21808c881c7b3ff27a84ac9171ea59442a9f5d4d 100644 (file)
@@ -61,7 +61,10 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
-    struct llama_kv_cache;
+
+    typedef struct llama_memory_i * llama_memory_t;
+
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -493,9 +496,11 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab    (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-    LLAMA_API    struct llama_kv_cache * llama_get_kv_self (      struct llama_context * ctx);
+    LLAMA_API           llama_memory_t   llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
+    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 
@@ -609,7 +614,78 @@ extern "C" {
                          int32_t   il_end);
 
     //
-    // KV cache
+    // Memory
+    //
+
+    // Clear the memory contents
+    LLAMA_API void llama_memory_clear(llama_memory_t mem);
+
+    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+    // seq_id < 0 : match any sequence
+    // p0 < 0     : [0,  p1]
+    // p1 < 0     : [p0, inf)
+    LLAMA_API bool llama_memory_seq_rm(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Copy all tokens that belong to the specified sequence to another sequence
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_cp(
+            llama_memory_t mem,
+              llama_seq_id seq_id_src,
+              llama_seq_id seq_id_dst,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_memory_seq_keep(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_add(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                 llama_pos delta);
+
+    // Integer division of the positions by factor of `d > 1`
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_div(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                       int d);
+
+    // Returns the smallest position present in the memory for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_min(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Returns the largest position present in the memory for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_max(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Check if the memory supports shifting
+    LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+    //
+    // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
     //
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -623,7 +699,7 @@ extern "C" {
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
-            struct llama_context * ctx);
+                struct llama_context * ctx);
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@@ -694,14 +770,14 @@ extern "C" {
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
             "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
             "simply remove this call, updates are applied lazily on the next llama_decode()");
 
     //
@@ -709,7 +785,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (logits, embedding and kv_cache)
+    // (logits, embedding and memory)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -765,12 +841,12 @@ extern "C" {
                           size_t   n_token_count),
         "use llama_state_save_file instead");
 
-    // Get the exact size needed to copy the KV cache of a single sequence
+    // Get the exact size needed to copy the state of a single sequence
     LLAMA_API size_t llama_state_seq_get_size(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
-    // Copy the KV cache of a single sequence into the specified buffer
+    // Copy the state of a single sequence into the specified buffer
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
@@ -836,16 +912,16 @@ extern "C" {
     // For encode-decoder contexts, processes the batch using the encoder.
     // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // < 0 - error. the memory state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
 
     // Process a batch of tokens.
-    // Requires KV cache.
+    // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     //    2 - aborted
index d20bd4fe2333d7747f782188fbc58a93dae3d946..70be604e4b0d336a30b24d9cf0601cde62b91a91 100644 (file)
@@ -20,7 +20,6 @@ add_library(llama
             llama-hparams.cpp
             llama-impl.cpp
             llama-io.cpp
-            llama-kv-cache.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
             llama-kv-cache-recurrent.cpp
index f1b43b9ccaaafa95abf68173972e42db59b3012f..c29fe7e4ce41dc8caf52075e463149d968051924 100644 (file)
@@ -2,9 +2,9 @@
 
 #include "llama-impl.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
 #include <cinttypes>
 #include <cstring>
@@ -277,10 +277,9 @@ llama_context::llama_context(
         int n_nodes_tg  = -1;
 
         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -288,7 +287,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -299,7 +298,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            auto * gf = graph_reserve(1, 1, 1, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -310,7 +309,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -419,14 +418,8 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }
 
-llama_kv_cache * llama_context::get_kv_self() {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
-}
-
-const llama_kv_cache * llama_context::get_kv_self() const {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
 }
 
 void llama_context::kv_self_defrag_sched() {
@@ -442,15 +435,13 @@ bool llama_context::kv_self_update(bool optimize) {
         return false;
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     {
         // TODO: remove in the future
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto kv_state = kv_self->init_update(this, optimize);
-        switch (kv_state->get_status()) {
+        const auto mstate = memory->init_update(this, optimize);
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -468,23 +459,25 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
-        if (!kv_state->apply()) {
+        if (!mstate->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
-    // if the KV cache did any computation, we have to reserve a new worst-case graph
-    const auto kv_state = kv_self->init_full();
-    if (!kv_state) {
-        throw std::runtime_error("failed to initialize memory state");
-    }
+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
+            throw std::runtime_error("failed to initialize memory state");
+        }
 
-    const uint32_t n_seqs   = cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        }
     }
 
     return true;
@@ -912,10 +905,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -977,21 +968,21 @@ int llama_context::decode(llama_batch & inp_batch) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr kv_state;
+    llama_memory_state_ptr mstate;
 
     while (true) {
-        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-        if (!kv_state) {
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!mstate) {
             return -2;
         }
 
-        switch (kv_state->get_status()) {
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
 
                     return -2;
                 }
@@ -1031,7 +1022,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = kv_state->get_ubatch();
+        const auto & ubatch = mstate->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -1054,7 +1045,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1076,7 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
                 LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
 
-                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+                memory->seq_rm(s, pos_min[s], -1);
             }
 
             switch (status) {
@@ -1170,7 +1161,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (kv_state->next());
+    } while (mstate->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1179,7 +1170,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = kv_state->out_ids();
+        auto & out_ids = mstate->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1847,11 +1838,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-        kv_self->state_write(io);
+        memory->state_write(io);
     }
 
     return io.n_bytes();
@@ -1938,9 +1927,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io);
+        memory->state_read(io);
     }
 
     return io.n_bytes();
@@ -1950,9 +1937,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_write(io, seq_id);
+        memory->state_write(io, seq_id);
     }
 
     return io.n_bytes();
@@ -1962,9 +1947,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io, seq_id);
+        memory->state_read(io, seq_id);
     }
 
     return io.n_bytes();
@@ -2069,9 +2052,7 @@ void llama_context::opt_epoch_iter(
     const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
     const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    kv_self->clear();
+    memory->clear();
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -2094,8 +2075,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2108,17 +2089,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = kv_state->get_ubatch();
+            const auto & ubatch = mstate->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!kv_state->apply()) {
+            if (!mstate->apply()) {
                 LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2153,7 +2134,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (kv_state->next());
+        } while (mstate->next());
     }
 }
 
@@ -2314,8 +2295,9 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->get_kv_self();
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }
 
 // deprecated
@@ -2435,13 +2417,82 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem) {
+    mem->clear();
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1) {
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+          llama_seq_id seq_id_src,
+          llama_seq_id seq_id_dst,
+             llama_pos p0,
+             llama_pos p1) {
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+             llama_pos delta) {
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+                   int d) {
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //
 
 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2463,7 +2514,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2483,12 +2534,12 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->clear();
+    llama_memory_clear(kv);
 }
 
 bool llama_kv_self_seq_rm(
@@ -2496,12 +2547,12 @@ bool llama_kv_self_seq_rm(
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return true;
     }
 
-    return kv->seq_rm(seq_id, p0, p1);
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }
 
 void llama_kv_self_seq_cp(
@@ -2510,21 +2561,21 @@ void llama_kv_self_seq_cp(
          llama_seq_id   seq_id_dst,
             llama_pos   p0,
             llama_pos   p1) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_keep(seq_id);
+    llama_memory_seq_keep(kv, seq_id);
 }
 
 void llama_kv_self_seq_add(
@@ -2533,12 +2584,12 @@ void llama_kv_self_seq_add(
             llama_pos   p0,
             llama_pos   p1,
             llama_pos   delta) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_add(seq_id, p0, p1, delta);
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }
 
 void llama_kv_self_seq_div(
@@ -2547,30 +2598,30 @@ void llama_kv_self_seq_div(
             llama_pos   p0,
             llama_pos   p1,
                   int   d) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_div(seq_id, p0, p1, d);
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }
 
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv->seq_pos_min(seq_id);
+    return llama_memory_seq_pos_min(kv, seq_id);
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv->seq_pos_max(seq_id);
+    return llama_memory_seq_pos_max(kv, seq_id);
 }
 
 // deprecated
@@ -2580,12 +2631,12 @@ void llama_kv_self_defrag(llama_context * ctx) {
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return false;
     }
 
-    return kv->get_can_shift();
+    return llama_memory_can_shift(kv);
 }
 
 // llama state API
index c1c7efb31fe32d0e8305ee1db9f678f47fbbc571..2e0da8c83bd59b5eca04b3c6d6fdfdf7f1a8fc30 100644 (file)
 #include <vector>
 
 struct llama_model;
-struct llama_kv_cache;
 
 class llama_io_read_i;
 class llama_io_write_i;
 
-class llama_memory_i;
-class llama_memory_state_i;
+struct llama_memory_i;
+struct llama_memory_state_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -47,8 +46,7 @@ struct llama_context {
     uint32_t n_threads()       const;
     uint32_t n_threads_batch() const;
 
-          llama_kv_cache * get_kv_self();
-    const llama_kv_cache * get_kv_self() const;
+    llama_memory_t get_memory() const;
 
     // return true of the KV cache was updated
     // TODO: remove
index d1c5dd1bf036f6ed1a8272d4da29b67901086f5d..2b1cfa5b7e2e778021c06a19c271e70248c88e78 100644 (file)
@@ -17,7 +17,7 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-class llama_memory_state_i;
+struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
index b32f258fbffbf9fd2a02e8b0b56f437ad5ddae5d..cb813dfe8be19265d2aa1f1d1a77d2a893f090ce 100644 (file)
@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
+#include "llama-memory.h"
 
 #include <set>
 #include <vector>
@@ -13,7 +13,7 @@
 
 // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
 //       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public llama_kv_cache {
+class llama_kv_cache_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_recurrent(
             const llama_model & model,
@@ -29,6 +29,16 @@ public:
     // llama_memory_i
     //
 
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
     void clear() override;
 
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
@@ -40,20 +50,6 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
-    //
-    // llama_kv_cache
-    //
-
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
-
-    llama_memory_state_ptr init_full() override;
-
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
-
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
     // find a contiguous slot of kv cells and emplace the ubatch there
index cba5bbe95197e66a4a6808814df6d8e1f08b6b76..3fabcd6b835db9cad42d9d9489ebadbb80a46922 100644 (file)
@@ -11,7 +11,7 @@
 // utilizes two instances of llama_kv_cache_unified
 //   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_kv_cache {
+class llama_kv_cache_unified_iswa : public llama_memory_i {
 public:
     llama_kv_cache_unified_iswa(
             const llama_model & model,
@@ -31,21 +31,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -58,6 +43,17 @@ public:
 
     bool get_can_shift() const override;
 
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
index 6ff388a88b1d58690d81215f18158e7d4c9a8589..d01a9abd7610a9f8f298c54b4ac8f43179557186 100644 (file)
@@ -2,8 +2,8 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
 #include "llama-kv-cells.h"
+#include "llama-memory.h"
 
 #include <unordered_map>
 #include <vector>
@@ -17,7 +17,7 @@ struct llama_context;
 // llama_kv_cache_unified
 //
 
-class llama_kv_cache_unified : public llama_kv_cache {
+class llama_kv_cache_unified : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
@@ -56,21 +56,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -83,6 +68,17 @@ public:
 
     bool get_can_shift() const override;
 
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
deleted file mode 100644 (file)
index aefd23e..0000000
+++ /dev/null
@@ -1 +0,0 @@
-#include "llama-kv-cache.h"
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
deleted file mode 100644 (file)
index 17a5e5c..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-#pragma once
-
-#include "llama.h"
-#include "llama-memory.h"
-
-class llama_io_write_i;
-class llama_io_read_i;
-
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-
-    // TODO: move the init_ interfaces to llama_memory_i
-
-    // split the input batch into a set of ubatches and verify that they can fit into the cache
-    // return a state object containing the ubatches and KV cache state required to process them
-    // check the llama_memory_state_i::get_status() for the result
-    virtual llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
-
-    // simulate full cache, used for allocating worst-case compute buffers
-    virtual llama_memory_state_ptr init_full() = 0;
-
-    // prepare for any pending memory updates, such as shifts, defrags, etc.
-    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
-    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
-
-    // getters
-    virtual bool get_can_shift() const = 0;
-
-    bool get_can_edit() const override { return get_can_shift(); }
-
-    //
-    // state write/read
-    //
-
-    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
-    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
-};
index ab0d399c4eb602ebe2e6af74afd6f14b1a18aaf3..5993b59bef0716152ffd6ad76faf32597df15856 100644 (file)
@@ -7,6 +7,9 @@
 
 struct llama_ubatch;
 
+class llama_io_write_i;
+class llama_io_read_i;
+
 struct llama_memory_params {
     // kv cache
     ggml_type type_k;
@@ -16,28 +19,6 @@ struct llama_memory_params {
     bool swa_full;
 };
 
-// general concept of LLM memory
-// the KV cache is a type of LLM memory, but there can be other types
-class llama_memory_i {
-public:
-    virtual ~llama_memory_i() = default;
-
-    virtual void clear() = 0;
-
-    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
-    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
-    virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
-    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
-
-    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
-    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
-
-    virtual bool get_can_edit() const = 0;
-};
-
-using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
 enum llama_memory_status {
     LLAMA_MEMORY_STATUS_SUCCESS = 0,
     LLAMA_MEMORY_STATUS_NO_UPDATE,
@@ -58,8 +39,7 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
 // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
 //
 // TODO: rename to llama_memory_context_i ?
-class llama_memory_state_i {
-public:
+struct llama_memory_state_i {
     virtual ~llama_memory_state_i() = default;
 
     // consume the current ubatch from the state and proceed to the next one
@@ -81,3 +61,57 @@ public:
 };
 
 using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+struct llama_memory_i {
+    virtual ~llama_memory_i() = default;
+
+    // split the input batch into a set of ubatches and verify that they can fit into the cache
+    // return a state object containing the ubatches and KV cache state required to process them
+    // check the llama_memory_state_i::get_status() for the result
+    virtual llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) = 0;
+
+    // simulate full cache, used for allocating worst-case compute buffers
+    virtual llama_memory_state_ptr init_full() = 0;
+
+    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+
+    // getters
+    virtual bool get_can_shift() const = 0;
+
+    //
+    // ops
+    //
+
+    virtual void clear() = 0;
+
+    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_keep(llama_seq_id seq_id) = 0;
+    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
+    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
+
+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
+
+    //
+    // state write/read
+    //
+
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
+};
+
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
+// TODO: temporary until the llama_kv_cache is removed from the public API
+struct llama_kv_cache : public llama_memory_i {
+    virtual ~llama_kv_cache() = default;
+};