]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 10 Jun 2025 07:12:44 +0000 (10:12 +0300)
committerGeorgi Gerganov <redacted>
Tue, 10 Jun 2025 09:40:33 +0000 (12:40 +0300)
ggml-ci

23 files changed:
examples/talk-llama/CMakeLists.txt
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-kv-cache-recurrent.cpp
examples/talk-llama/llama-kv-cache-recurrent.h
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
examples/talk-llama/llama-kv-cache-unified-iswa.h
examples/talk-llama/llama-kv-cache-unified.cpp
examples/talk-llama/llama-kv-cache-unified.h
examples/talk-llama/llama-kv-cache.cpp [deleted file]
examples/talk-llama/llama-kv-cells.h
examples/talk-llama/llama-memory.cpp
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-mmap.cpp
examples/talk-llama/llama-model-loader.cpp
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama.h

index da190e33e718292ec9919945ef8d047c3a4f29db..d53546386647881e1cdde5e0cc3c1fbca149bee6 100644 (file)
@@ -16,7 +16,6 @@ if (WHISPER_SDL2)
         llama-hparams.cpp
         llama-impl.cpp
         llama-io.cpp
-        llama-kv-cache.cpp
         llama-kv-cache-unified.cpp
         llama-kv-cache-unified-iswa.cpp
         llama-kv-cache-recurrent.cpp
index c0590e105c8895b9446703752241fd1e2d758368..43fa60a8070b765f09dff5ef73124bb342cb0d52 100644 (file)
@@ -200,7 +200,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,              "tokenizer.huggingface.json"              },
     { LLM_KV_TOKENIZER_RWKV,                 "tokenizer.rwkv.world"                    },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,        "tokenizer.chat_template"                 },
-    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,      "tokenizer.chat_template.%s"              },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,           "tokenizer.ggml.fim_pre_token_id"         },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,           "tokenizer.ggml.fim_suf_token_id"         },
     { LLM_KV_TOKENIZER_FIM_MID_ID,           "tokenizer.ggml.fim_mid_token_id"         },
@@ -1707,8 +1706,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
-        : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+
+    if (suffix != nullptr) {
+        name += ".";
+        name += suffix;
+    }
+
+    return name;
 }
 
 std::string LLM_TN_IMPL::str() const {
index 930cb4eca33ab7dac049d7c45b94836ff310defc..f3825528aefdb93cf9f571fb479de36fe5aadee4 100644 (file)
@@ -196,7 +196,6 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
-    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
index 4ab57438794005fb3416b792156994a48b70d448..b130b484bcf6fd4c2c38a1b64de13962efe55be6 100644 (file)
@@ -2,9 +2,9 @@
 
 #include "llama-impl.h"
 #include "llama-io.h"
+#include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
-#include "llama-kv-cache.h"
 
 #include <cinttypes>
 #include <cstring>
@@ -123,7 +123,7 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    if (!params.swa_full && cparams.n_seq_max > 1) {
+    if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) {
         LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
                 __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
     }
@@ -277,10 +277,9 @@ llama_context::llama_context(
         int n_nodes_tg  = -1;
 
         // simulate full KV cache
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-        const auto kv_state = kv_self->init_full();
-        if (!kv_state) {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -288,7 +287,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -299,7 +298,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, kv_state.get());
+            auto * gf = graph_reserve(1, 1, 1, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -310,7 +309,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -419,40 +418,68 @@ uint32_t llama_context::n_threads_batch() const {
     return cparams.n_threads_batch;
 }
 
-llama_kv_cache * llama_context::get_kv_self() {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
+llama_memory_t llama_context::get_memory() const {
+    return memory.get();
 }
 
-const llama_kv_cache * llama_context::get_kv_self() const {
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-    return kv_self;
+// deprecated
+void llama_context::kv_self_defrag_sched() {
+    if (!memory) {
+        return;
+    }
+
+    memory_force_optimize = true;
 }
 
-bool llama_context::kv_self_update() {
+// deprecated
+bool llama_context::kv_self_update(bool optimize) {
     if (!memory) {
         return false;
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    {
+        // TODO: remove in the future
+        optimize |= memory_force_optimize;
+        memory_force_optimize = false;
 
-    if (!kv_self->update(*this)) {
-        // no updates have been performed
-        return false;
-    }
+        const auto mstate = memory->init_update(this, optimize);
+        switch (mstate->get_status()) {
+            case LLAMA_MEMORY_STATUS_SUCCESS:
+                {
+                    // noop
+                } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    // no updates need to be performed
+                    return false;
+                }
+            case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+            case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+                {
+                    LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
+                    return false;
+                }
+        }
 
-    // if the KV cache did any computation, we have to reserve a new worst-case graph
-    const auto kv_state = kv_self->init_full();
-    if (!kv_state) {
-        throw std::runtime_error("failed to initialize KV cache");
+        if (!mstate->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
+        }
     }
 
-    const uint32_t n_seqs   = cparams.n_seq_max;
-    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+    // if the memory module did any computation, we have to reserve a new worst-case graph
+    {
+        const auto mstate = memory->init_full();
+        if (!mstate) {
+            throw std::runtime_error("failed to initialize memory state");
+        }
 
-    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
-    if (!gf) {
-        LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
+        const uint32_t n_seqs   = cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        if (!gf) {
+            LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
+        }
     }
 
     return true;
@@ -814,16 +841,17 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score - a single float per sequence
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;
+                    const uint32_t n_cls_out = hparams.n_cls_out;
 
                     for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
                         const llama_seq_id seq_id = ubatch.seq_id[s][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
                         }
-                        embd_seq_out[seq_id].resize(1);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -880,10 +908,8 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
     // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -940,42 +966,49 @@ int llama_context::decode(llama_batch & inp_batch) {
         n_outputs_all = 1;
     }
 
-    // handle any pending defrags/shifts
-    kv_self_update();
+    bool did_optimize = false;
 
-    llama_memory_state_ptr kv_state;
+    // handle any pending defrags/shifts
+    kv_self_update(false);
 
-    bool did_defrag = false;
+    llama_memory_state_ptr mstate;
 
     while (true) {
-        kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
-        if (!kv_state) {
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        if (!mstate) {
             return -2;
         }
 
-        switch (kv_state->get_status()) {
+        switch (mstate->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
+            case LLAMA_MEMORY_STATUS_NO_UPDATE:
+                {
+                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+
+                    return -2;
+                }
             case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
                 {
-                    if (!did_defrag) {
-                        did_defrag = true;
+                    if (!did_optimize) {
+                        did_optimize = true;
 
-                        kv_self->defrag_sched(-1.0f);
-                        if (kv_self_update()) {
-                            LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
+                        if (kv_self_update(true)) {
+                            LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
 
                             continue;
                         }
                     }
 
-                    LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+                    LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
 
                     return 1;
                 }
             case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
                 {
+                    LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
+
                     return -2;
                 }
         }
@@ -992,7 +1025,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = kv_state->get_ubatch();
+        const auto & ubatch = mstate->get_ubatch();
 
         // count the outputs in this u_batch
         {
@@ -1015,11 +1048,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
-            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES] = { std::numeric_limits<llama_pos>::max() };
+            llama_pos pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
+            for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+                pos_min[s] = std::numeric_limits<llama_pos>::max();
+            }
 
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
@@ -1034,7 +1070,7 @@ int llama_context::decode(llama_batch & inp_batch) {
 
                 LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
 
-                llama_kv_self_seq_rm(this, s, pos_min[s], -1);
+                memory->seq_rm(s, pos_min[s], -1);
             }
 
             switch (status) {
@@ -1128,7 +1164,7 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (kv_state->next());
+    } while (mstate->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1137,7 +1173,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     {
         bool sorted_output = true;
 
-        auto & out_ids = kv_state->out_ids();
+        auto & out_ids = mstate->out_ids();
 
         GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
@@ -1189,11 +1225,6 @@ int llama_context::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    // decide if we need to defrag the kv cache
-    if (cparams.defrag_thold > 0.0f) {
-        kv_self->defrag_sched(cparams.defrag_thold);
-    }
-
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
     ggml_backend_sched_reset(sched.get());
@@ -1810,11 +1841,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    if (kv_self != nullptr) {
+    if (memory != nullptr) {
         LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
-        kv_self->state_write(io);
+        memory->state_write(io);
     }
 
     return io.n_bytes();
@@ -1901,9 +1930,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     if (memory) {
         LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io);
+        memory->state_read(io);
     }
 
     return io.n_bytes();
@@ -1913,9 +1940,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_write(io, seq_id);
+        memory->state_write(io, seq_id);
     }
 
     return io.n_bytes();
@@ -1925,9 +1950,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
     GGML_UNUSED(seq_id);
 
     if (memory) {
-        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-        kv_self->state_read(io, seq_id);
+        memory->state_read(io, seq_id);
     }
 
     return io.n_bytes();
@@ -2032,9 +2055,7 @@ void llama_context::opt_epoch_iter(
     const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
     const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
-
-    kv_self->clear();
+    memory->clear(true);
 
     for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
         batch.n_tokens = n_batch;
@@ -2057,8 +2078,8 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
-        if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2071,17 +2092,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = kv_state->get_ubatch();
+            const auto & ubatch = mstate->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!kv_state->apply()) {
+            if (!mstate->apply()) {
                 LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2116,7 +2137,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (kv_state->next());
+        } while (mstate->next());
     }
 }
 
@@ -2277,13 +2298,14 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
+// deprecated
 llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return ctx->get_kv_self();
+    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
 }
 
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
-    ctx->kv_self_update();
+    ctx->kv_self_update(false);
 }
 
 enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
@@ -2398,13 +2420,118 @@ int32_t llama_apply_adapter_cvec(
     return res ? 0 : -1;
 }
 
+//
+// memory
+//
+
+llama_memory_t llama_get_memory(const struct llama_context * ctx) {
+    return ctx->get_memory();
+}
+
+void llama_memory_clear(llama_memory_t mem, bool data) {
+    if (!mem) {
+        return;
+    }
+
+    mem->clear(data);
+}
+
+bool llama_memory_seq_rm(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1) {
+    if (!mem) {
+        return true;
+    }
+
+    return mem->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_seq_cp(
+        llama_memory_t mem,
+          llama_seq_id seq_id_src,
+          llama_seq_id seq_id_dst,
+             llama_pos p0,
+             llama_pos p1) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_seq_keep(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_keep(seq_id);
+}
+
+void llama_memory_seq_add(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+             llama_pos delta) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_add(seq_id, p0, p1, delta);
+}
+
+void llama_memory_seq_div(
+        llama_memory_t mem,
+          llama_seq_id seq_id,
+             llama_pos p0,
+             llama_pos p1,
+                   int d) {
+    if (!mem) {
+        return;
+    }
+
+    mem->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_seq_pos_min(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_min(seq_id);
+}
+
+llama_pos llama_memory_seq_pos_max(
+        llama_memory_t mem,
+          llama_seq_id seq_id) {
+    if (!mem) {
+        return -1;
+    }
+
+    return mem->seq_pos_max(seq_id);
+}
+
+bool llama_memory_can_shift(llama_memory_t mem) {
+    if (!mem) {
+        return false;
+    }
+
+    return mem->get_can_shift();
+}
+
 //
 // kv cache
 //
 
 // deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2426,7 +2553,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
 // deprecated
 // note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    const auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return 0;
     }
@@ -2445,115 +2572,119 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     return res;
 }
 
+// deprecated
 void llama_kv_self_clear(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->clear();
+    llama_memory_clear(kv, true);
 }
 
+// deprecated
 bool llama_kv_self_seq_rm(
         llama_context * ctx,
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return true;
     }
 
-    return kv->seq_rm(seq_id, p0, p1);
+    return llama_memory_seq_rm(kv, seq_id, p0, p1);
 }
 
+// deprecated
 void llama_kv_self_seq_cp(
         llama_context * ctx,
          llama_seq_id   seq_id_src,
          llama_seq_id   seq_id_dst,
             llama_pos   p0,
             llama_pos   p1) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
 }
 
+// deprecated
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_keep(seq_id);
+    llama_memory_seq_keep(kv, seq_id);
 }
 
+// deprecated
 void llama_kv_self_seq_add(
         llama_context * ctx,
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1,
             llama_pos   delta) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_add(seq_id, p0, p1, delta);
+    llama_memory_seq_add(kv, seq_id, p0, p1, delta);
 }
 
+// deprecated
 void llama_kv_self_seq_div(
         llama_context * ctx,
          llama_seq_id   seq_id,
             llama_pos   p0,
             llama_pos   p1,
                   int   d) {
-    auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return;
     }
 
-    kv->seq_div(seq_id, p0, p1, d);
+    llama_memory_seq_div(kv, seq_id, p0, p1, d);
 }
 
+// deprecated
 llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv->seq_pos_min(seq_id);
+    return llama_memory_seq_pos_min(kv, seq_id);
 }
 
+// deprecated
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return -1;
     }
 
-    return kv->seq_pos_max(seq_id);
+    return llama_memory_seq_pos_max(kv, seq_id);
 }
 
 // deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
-    auto * kv = ctx->get_kv_self();
-    if (!kv) {
-        return;
-    }
-
     // force defrag
-    kv->defrag_sched(-1.0f);
+    ctx->kv_self_defrag_sched();
 }
 
+// deprecated
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-    const auto * kv = ctx->get_kv_self();
+    auto * kv = llama_get_memory(ctx);
     if (!kv) {
         return false;
     }
 
-    return kv->get_can_shift();
+    return llama_memory_can_shift(kv);
 }
 
 // llama state API
index 3b880286bfd5de755467aea8eef849a3f13df766..2e0da8c83bd59b5eca04b3c6d6fdfdf7f1a8fc30 100644 (file)
 #include <vector>
 
 struct llama_model;
-struct llama_kv_cache;
 
 class llama_io_read_i;
 class llama_io_write_i;
 
-class llama_memory_i;
-class llama_memory_state_i;
+struct llama_memory_i;
+struct llama_memory_state_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -47,12 +46,12 @@ struct llama_context {
     uint32_t n_threads()       const;
     uint32_t n_threads_batch() const;
 
-          llama_kv_cache * get_kv_self();
-    const llama_kv_cache * get_kv_self() const;
+    llama_memory_t get_memory() const;
 
     // return true of the KV cache was updated
     // TODO: remove
-    bool kv_self_update();
+    bool kv_self_update(bool optimize);
+    void kv_self_defrag_sched();
 
     enum llama_pooling_type pooling_type() const;
 
@@ -231,6 +230,9 @@ private:
 
     std::unique_ptr<llama_memory_i> memory;
 
+    // TODO: temporary, until the llama_kv_self_defrag() API is removed
+    bool memory_force_optimize = false;
+
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
     float * logits      = nullptr;
index 727e119e334f6093a00a26f8d1ea3fc056c3addd..27c9ab74be1125e6b7811c75a0a5ae92bf6be3a0 100644 (file)
@@ -659,6 +659,20 @@ ggml_tensor * llm_graph_context::build_ffn(
                 cur = ggml_mul(ctx0, x0, x1);
                 cb(cur, "ffn_mul", il);
             } break;
+        case LLM_FFN_GEGLU:
+            {
+                // Split into two equal parts
+                int64_t split_point = cur->ne[0] / 2;
+                // TODO: these conts should not be needed
+                ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_gelu(ctx0, x0);
+                cb(x0, "ffn_gelu", il);
+
+                cur = ggml_mul(ctx0, x0, x1);
+                cb(cur, "ffn_geglu", il);
+            } break;
     }
 
     if (gate && type_gate == LLM_FFN_PAR) {
@@ -769,9 +783,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
-        // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
-        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
-        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        // repeat cur to [n_embd, n_expert_used, n_tokens]
+        ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1);
         cur = ggml_mul(ctx0, repeated, weights);
         cb(cur, "ffn_moe_weighted", il);
     }
index d1c5dd1bf036f6ed1a8272d4da29b67901086f5d..28da6a5228bdcbb928d2d16c8da1e2a6bce7b027 100644 (file)
@@ -17,7 +17,7 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-class llama_memory_state_i;
+struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
@@ -36,6 +36,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
+    LLM_FFN_GEGLU,
 };
 
 enum llm_ffn_gate_type {
index 641eab2f316cef56d8be36c472118b3099e642e0..f5c6dcd66ce9e8a519d7158a28b9cd5fcd29568b 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-recurrent.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-batch.h"
 #include "llama-model.h"
 
@@ -116,18 +117,21 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     }
 }
 
-void llama_kv_cache_recurrent::clear() {
+void llama_kv_cache_recurrent::clear(bool data) {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
         cells[i].src = -1;
         cells[i].tail = -1;
     }
+
     head = 0;
     used = 0;
 
-    for (auto & buf : bufs) {
-        ggml_backend_buffer_clear(buf.get(), 0);
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
     }
 }
 
@@ -386,6 +390,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
     return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
 }
 
+llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
 bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
     // simply remember the full state because it is very small for this type of cache
     // TODO: optimize
@@ -419,17 +430,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
     return success;
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
-    // noop
-    return false;
-}
-
-void llama_kv_cache_recurrent::defrag_sched(float thold) {
-    GGML_UNUSED(thold);
-    // noop
-}
-
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     const uint32_t n_tokens = ubatch.n_tokens;
     const uint32_t n_seqs   = ubatch.n_seqs;
@@ -726,7 +726,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
 
     if (!res) {
         if (seq_id == -1) {
-            clear();
+            clear(true);
         } else {
             seq_rm(seq_id, -1, -1);
         }
@@ -883,7 +883,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
             return false;
         }
 
-        clear();
+        clear(true);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             kv_cell & cell = cells[i];
index a178ae85c146a8e8f36579668aa2de3e67012cca..d1da1225655fa4b18d0ae055df131ebff2c9d4cc 100644 (file)
@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
+#include "llama-memory.h"
 
 #include <set>
 #include <vector>
@@ -13,7 +13,7 @@
 
 // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
 //       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public llama_kv_cache {
+class llama_kv_cache_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_recurrent(
             const llama_model & model,
@@ -29,21 +29,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -52,9 +37,18 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    void clear(bool data) override;
 
-    void defrag_sched(float thold) override;
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
index 0eb04563435467b5845668f33bf3f15b96adc47c..28d18265476497e42b93087b8c336a360d6e0d13 100644 (file)
@@ -52,9 +52,9 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
             hparams.n_swa, hparams.swa_type);
 }
 
-void llama_kv_cache_unified_iswa::clear() {
-    kv_base->clear();
-    kv_swa ->clear();
+void llama_kv_cache_unified_iswa::clear(bool data) {
+    kv_base->clear(data);
+    kv_swa ->clear(data);
 }
 
 bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     assert(heads_base.size() == heads_swa.size());
 
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(
             this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
 }
 
-bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_base->update(lctx);
-    res = res | kv_swa ->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
-    kv_base->defrag_sched(thold);
-    kv_swa ->defrag_sched(thold);
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
-        llama_kv_cache_unified_iswa * kv) : status(status) {
-    state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
-    state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
+        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_full();
+    state_swa  = kv->get_swa ()->init_full();
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
+
+llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+        llama_kv_cache_unified_iswa * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_base = kv->get_base()->init_update(lctx, optimize);
+    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_memory_status status,
         llama_kv_cache_unified_iswa * kv,
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches)
-    : status(status),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
-        // note: here we copy the ubatches. not sure if this is ideal
-        state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
-        state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
-    }
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+        sbatch(std::move(sbatch)),
+        ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
+    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
+
+    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+}
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
 
@@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
 
 const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
     return ubatches[i_next];
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_base.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa()  const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return state_swa.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
 }
index 8b067da038af6af2e8322aa98f2920f3f23d369a..3dbf33ed7b960d3985804e84a63eccddbb308198 100644 (file)
@@ -11,7 +11,7 @@
 // utilizes two instances of llama_kv_cache_unified
 //   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_kv_cache {
+class llama_kv_cache_unified_iswa : public llama_memory_i {
 public:
     llama_kv_cache_unified_iswa(
             const llama_model & model,
@@ -31,21 +31,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -54,12 +39,21 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -86,12 +80,16 @@ public:
 
     // used to create a full-cache state
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update state
+    llama_kv_cache_unified_iswa_state(
+            llama_kv_cache_unified_iswa * kv,
+            llama_context * lctx,
+            bool optimize);
+
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
-            llama_memory_status status,
             llama_kv_cache_unified_iswa * kv,
             llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
@@ -120,7 +118,7 @@ public:
     const llama_kv_cache_unified_state * get_swa()  const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     //llama_kv_cache_unified_iswa * kv;
 
@@ -131,6 +129,6 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    std::unique_ptr<llama_kv_cache_unified_state> state_base;
-    std::unique_ptr<llama_kv_cache_unified_state> state_swa;
+    llama_memory_state_ptr state_base;
+    llama_memory_state_ptr state_swa;
 };
index a817154769a32306740c826ece3d6d4f60aee5d3..3566d5fd4d72bdf5a7e8a968f18197e7743815ae 100644 (file)
@@ -1,6 +1,7 @@
 #include "llama-kv-cache-unified.h"
 
 #include "llama-impl.h"
+#include "llama-io.h"
 #include "llama-model.h"
 #include "llama-context.h"
 
@@ -128,13 +129,15 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
 }
 
-void llama_kv_cache_unified::clear() {
+void llama_kv_cache_unified::clear(bool data) {
     cells.reset();
 
     head = 0;
 
-    for (auto & buf : bufs) {
-        ggml_backend_buffer_clear(buf.get(), 0);
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
     }
 }
 
@@ -149,12 +152,27 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    for (uint32_t i = 0; i < cells.size(); ++i) {
-        if (!cells.pos_in(i, p0, p1)) {
-            continue;
+    if (seq_id >= 0) {
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
+            }
         }
+    } else {
+        // match any sequence
+        for (uint32_t i = 0; i < cells.size(); ++i) {
+            if (!cells.pos_in(i, p0, p1)) {
+                continue;
+            }
+
+            cells.rm(i);
 
-        if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
             if (new_head == cells.size()) {
                 new_head = i;
             }
@@ -305,16 +323,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
+    return std::make_unique<llama_kv_cache_unified_state>(
             this, std::move(sbatch), std::move(heads), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_unified_state>(this);
+}
+
+llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+    bool do_shift = get_has_shift();
+
+    defrag_info dinfo;
+
+    // see if we need to defrag
+    {
+        bool do_defrag = optimize;
+
+        const auto thold = lctx->get_cparams().defrag_thold;
+
+        if (!do_defrag && thold > 0.0f) {
+            const auto n_kv = cells.used_max_p1();
+
+            // - do not defrag small contexts (i.e. < 2048 tokens)
+            // - count the padding towards the number of used tokens
+            const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+
+            if (fragmentation > thold) {
+                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+                do_defrag = true;
+            }
+        }
+
+        if (do_defrag) {
+            dinfo = defrag_prepare(lctx->graph_max_nodes());
+        }
+    }
+
+    return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
 }
 
-std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    std::vector<uint32_t> res;
+llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache_unified::ubatch_heads res;
 
     struct state {
         uint32_t head_old; // old position of the head, before placing the ubatch
@@ -359,12 +410,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context & lctx) {
+bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
     bool updated = false;
 
-    auto * sched = lctx.get_sched();
+    auto * sched = lctx->get_sched();
 
-    if (cells.get_has_shift()) {
+    if (do_shift) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -375,9 +426,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
             ggml_backend_sched_reset(sched);
 
-            auto * gf = lctx.graph_init();
+            auto * gf = lctx->graph_init();
 
-            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
+            auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
             if (!res) {
                 LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
                 return updated;
@@ -390,7 +441,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 
             res->set_inputs(nullptr);
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
                 LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
                 return updated;
             }
@@ -401,54 +452,53 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
         cells.reset_shift();
     }
 
-    if (do_defrag) {
+    if (!dinfo.empty()) {
         LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        if (defrag_prepare(lctx.graph_max_nodes())) {
-            ggml_backend_sched_reset(sched);
-
-            auto * gf = lctx.graph_init();
-
-            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
-            if (!res) {
-                LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
-                return updated;
-            }
+        // apply moves:
+        {
+            const auto n_kv = dinfo.ids.size();
 
-            if (!ggml_backend_sched_alloc_graph(sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
-                return updated;
-            }
+            for (uint32_t i = 0; i < n_kv; ++i) {
+                assert(dinfo.ids[i] <= n_kv);
 
-            res->set_inputs(nullptr);
+                if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) {
+                    continue;
+                }
 
-            if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
-                LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
-                return updated;
+                cells.mv(i, dinfo.ids[i]);
             }
 
-            updated = true;
+            // reset the head so we can find the first free slot during the next ubatch
+            head = 0;
         }
 
-        do_defrag = false;
-    }
+        ggml_backend_sched_reset(sched);
 
-    return updated;
-}
+        auto * gf = lctx->graph_init();
 
-void llama_kv_cache_unified::defrag_sched(float thold) {
-    const auto n_kv = cells.used_max_p1();
+        auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
+        if (!res) {
+            LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
+            return updated;
+        }
+
+        if (!ggml_backend_sched_alloc_graph(sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
+            return updated;
+        }
 
-    // - do not defrag small contexts (i.e. < 2048 tokens)
-    // - count the padding towards the number of used tokens
-    const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
+        res->set_inputs(nullptr);
 
-    // queue defragmentation for next llama_kv_cache_update
-    if (fragmentation > thold) {
-        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+        if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
+            LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
+            return updated;
+        }
 
-        do_defrag = true;
+        updated = true;
     }
+
+    return updated;
 }
 
 int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
@@ -597,6 +647,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
     return cells.size();
 }
 
+bool llama_kv_cache_unified::get_has_shift() const {
+    return cells.get_has_shift();
+}
+
 uint32_t llama_kv_cache_unified::get_n_kv() const {
     return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
 }
@@ -890,11 +944,9 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    //GGML_ASSERT(kv_self->size == n_ctx);
-
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
     ggml_set_input(inp->k_shift);
 
     for (const auto & layer : layers) {
@@ -926,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 }
 
 llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
-        const llama_cparams & cparams,
-               ggml_context * ctx,
-                ggml_cgraph * gf) const {
+                const llama_cparams & cparams,
+                       ggml_context * ctx,
+                        ggml_cgraph * gf,
+                  const defrag_info & dinfo) const {
     auto res = std::make_unique<llm_graph_result>();
 
-    const auto & ids = defrag_info.ids;
+    const auto & ids = dinfo.ids;
 
 #if 0
     // CPU defrag
@@ -1072,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
     return res;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cells.used_max_p1();
@@ -1093,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
+    defrag_info res;
+    auto & ids = res.ids;
 
-    ids.clear();
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
@@ -1164,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             // this cell goes to (i0 + nf)
             ids[i1] = i0 + nf;
 
-            // move the cell meta data
-            cells.mv(i1, i0 + nf);
-
-            head = n_used;
-
             if (!cont) {
                 n_moves++;
                 cont = true;
@@ -1191,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     }
 
     if (n_moves == 0) {
-        return false;
+        return {};
     }
 
     LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
     LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
-    return true;
+    return res;
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
@@ -1276,7 +1319,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 
     if (!res) {
         if (seq_id == -1) {
-            clear();
+            clear(true);
         } else {
             seq_rm(seq_id, -1, -1);
         }
@@ -1457,7 +1500,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             return false;
         }
 
-        clear();
+        clear(true);
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1621,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-            llama_memory_status status,
-            llama_kv_cache_unified * kv) : status(status), kv(kv) {
-        n_kv = kv->get_size();
-        head = 0;
-    }
+        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+    n_kv = kv->get_size();
+    head = 0;
+}
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
-            llama_memory_status status,
-            llama_kv_cache_unified * kv,
-            llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
-            std::vector<llama_ubatch> ubatches)
-            : status(status),
-              kv(kv),
-              sbatch(std::move(sbatch)),
-              heads(std::move(heads)),
-              ubatches(std::move(ubatches)) {
+        llama_kv_cache_unified * kv,
+        llama_context * lctx,
+        bool do_shift,
+        defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
+    if (!do_shift && dinfo.empty()) {
+        status = LLAMA_MEMORY_STATUS_NO_UPDATE;
     }
+}
+
+llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+        llama_kv_cache_unified * kv,
+        llama_sbatch sbatch,
+        llama_kv_cache_unified::ubatch_heads heads,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+}
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
 
@@ -1655,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
 bool llama_kv_cache_unified_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    // no ubatches -> this is a KV cache update
+    if (ubatches.empty()) {
+        kv->update(lctx, do_shift, dinfo);
+
+        return true;
+    }
+
     kv->apply_ubatch(heads[i_next], ubatches[i_next]);
 
     n_kv = kv->get_n_kv();
index 1f1d44b97c2ac713e54bdeb0f63ec40a7d9f275e..49f410ef6ecabf11aa267b0d0568f66934a41a05 100644 (file)
@@ -2,8 +2,8 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
 #include "llama-kv-cells.h"
+#include "llama-memory.h"
 
 #include <unordered_map>
 #include <vector>
@@ -17,13 +17,26 @@ struct llama_context;
 // llama_kv_cache_unified
 //
 
-class llama_kv_cache_unified : public llama_kv_cache {
+class llama_kv_cache_unified : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
     // this callback is used to filter out layers that should not be included in the cache
     using layer_filter_cb = std::function<bool(int32_t il)>;
 
+    using ubatch_heads = std::vector<uint32_t>;
+
+    struct defrag_info {
+        bool empty() const {
+            return ids.empty();
+        }
+
+        // contains information about which cell moves where:
+        //  - cell i moves to ids[i]
+        //  - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
+        std::vector<uint32_t> ids;
+    };
+
     llama_kv_cache_unified(
             const llama_model &  model,
               layer_filter_cb && filter,
@@ -43,21 +56,6 @@ public:
     // llama_memory_i
     //
 
-    void clear() override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    //
-    // llama_kv_cache
-    //
-
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
@@ -66,12 +64,21 @@ public:
 
     llama_memory_state_ptr init_full() override;
 
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -83,6 +90,8 @@ public:
 
     uint32_t get_size() const;
 
+    bool get_has_shift() const;
+
     //
     // graph_build API
     //
@@ -103,7 +112,9 @@ public:
 
     // find places for the provided ubatches in the cache, returns the head locations
     // return empty vector on failure
-    std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
+    ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
+
+    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
 
     // return the cell position where we can insert the ubatch
     // return -1 on failure to find a contiguous slot of kv cells
@@ -133,8 +144,7 @@ private:
         ggml_tensor * v;
     };
 
-    bool do_defrag = false;
-    bool v_trans   = true;  // the value tensor is transposed
+    bool v_trans = true;  // the value tensor is transposed
 
     // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
@@ -160,13 +170,8 @@ private:
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
-    // defrag
-    struct {
-        std::vector<uint32_t> ids;
-    } defrag_info;
-
-    // return true if cells have been moved
-    bool defrag_prepare(int32_t n_max_nodes);
+    // return non-empty vector if cells have been moved
+    defrag_info defrag_prepare(int32_t n_max_nodes) const;
 
     size_t total_size() const;
 
@@ -192,7 +197,8 @@ private:
     llm_graph_result_ptr build_graph_defrag(
             const llama_cparams & cparams,
                    ggml_context * ctx,
-                    ggml_cgraph * gf) const;
+                    ggml_cgraph * gf,
+              const defrag_info & dinfo) const;
 
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@@ -203,20 +209,29 @@ private:
 
 class llama_kv_cache_unified_state : public llama_memory_state_i {
 public:
+    // some shorthands
+    using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
+    using defrag_info  = llama_kv_cache_unified::defrag_info;
+
     // used for errors
     llama_kv_cache_unified_state(llama_memory_status status);
 
     // used to create a full-cache state
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv);
 
-    // used to create a state from a batch
+    // used to create an update state
+    llama_kv_cache_unified_state(
+            llama_kv_cache_unified * kv,
+            llama_context * lctx,
+            bool do_shift,
+            defrag_info dinfo);
+
+    // used to create a decode state from a batch
     llama_kv_cache_unified_state(
-            llama_memory_status status,
             llama_kv_cache_unified * kv,
             llama_sbatch sbatch,
-            std::vector<uint32_t> heads,
+            ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
     virtual ~llama_kv_cache_unified_state();
@@ -253,16 +268,30 @@ public:
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     llama_kv_cache_unified * kv;
+    llama_context * lctx;
+
+    //
+    // update state
+    //
+
+    bool do_shift = false;
+
+    defrag_info dinfo;
+
+    //
+    // batch processing state
+    //
 
     llama_sbatch sbatch;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
 
-    std::vector<uint32_t> heads;
+    ubatch_heads heads;
+
     std::vector<llama_ubatch> ubatches;
 
     //
diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp
deleted file mode 100644 (file)
index aefd23e..0000000
+++ /dev/null
@@ -1 +0,0 @@
-#include "llama-kv-cache.h"
index 9e2c4d927699d72300ed013b17135a9b12c63bdb..acf30aebec69b071bc0b36223771e71ce0a5f9fd 100644 (file)
@@ -80,6 +80,9 @@ public:
         assert(isrc < pos.size());
         assert(idst < pos.size());
 
+        assert(pos[idst] == -1);
+        assert(pos[isrc] != -1);
+
         pos  [idst] = pos  [isrc];
         shift[idst] = shift[isrc];
         seq  [idst] = seq  [isrc];
@@ -144,9 +147,10 @@ public:
         assert(pos[i] != -1);
 
         seq_pos_rm(i);
+        seq[i].reset();
 
         pos[i] = -1;
-        seq[i].reset();
+        shift[i] = 0;
 
         used.erase(i);
     }
@@ -164,6 +168,7 @@ public:
 
         if (seq[i].none()) {
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
@@ -192,6 +197,7 @@ public:
             seq[i].reset();
 
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
@@ -317,21 +323,20 @@ public:
         pos[i]   += d;
         shift[i] += d;
 
-        seq_pos_add(i);
-
         has_shift = true;
 
         if (pos[i] < 0) {
-            seq_pos_rm(i);
-
             seq[i].reset();
             pos[i] = -1;
+            shift[i] = 0;
 
             used.erase(i);
 
             return true;
         }
 
+        seq_pos_add(i);
+
         return false;
     }
 
index 10173253edfe448ad9aeba27be3a9deee484bff7..f1107672c6476411b04521db02379255328e7728 100644 (file)
@@ -1 +1,42 @@
 #include "llama-memory.h"
+
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
+    bool has_update = false;
+
+    switch (s0) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s0;
+            }
+    }
+
+    switch (s1) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+            {
+                has_update = true;
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                break;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return s1;
+            }
+    }
+
+    // if either status has an update, then the combined status has an update
+    return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
+}
index b3799d66e8c170241284abbb4735cc729a553512..991aae781ba57003d2d975994f848ff69735a931 100644 (file)
@@ -7,6 +7,9 @@
 
 struct llama_ubatch;
 
+class llama_io_write_i;
+class llama_io_read_i;
+
 struct llama_memory_params {
     // kv cache
     ggml_type type_k;
@@ -16,32 +19,17 @@ struct llama_memory_params {
     bool swa_full;
 };
 
-// general concept of LLM memory
-// the KV cache is a type of LLM memory, but there can be other types
-class llama_memory_i {
-public:
-    virtual ~llama_memory_i() = default;
-
-    virtual void clear() = 0;
-
-    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
-    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
-    virtual void seq_keep(llama_seq_id seq_id) = 0;
-    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
-    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
-
-    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
-    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
-
-    virtual bool get_can_edit() const = 0;
-};
-
 enum llama_memory_status {
     LLAMA_MEMORY_STATUS_SUCCESS = 0,
+    LLAMA_MEMORY_STATUS_NO_UPDATE,
     LLAMA_MEMORY_STATUS_FAILED_PREPARE,
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };
 
+// helper function for combining the status of two memory states
+// useful for implementing hybrid memory types (e.g. iSWA)
+llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
+
 // the interface for managing the memory state during batch processing
 // this interface is implemented per memory type. see:
 //   - llama_kv_cache_unified_state
@@ -51,8 +39,7 @@ enum llama_memory_status {
 // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
 //
 // TODO: rename to llama_memory_context_i ?
-class llama_memory_state_i {
-public:
+struct llama_memory_state_i {
     virtual ~llama_memory_state_i() = default;
 
     // consume the current ubatch from the state and proceed to the next one
@@ -69,8 +56,63 @@ public:
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;
 
-    // get the status of the memory state
+    // get the status of the memory state - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };
 
 using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
+
+// general concept of LLM memory
+// the KV cache is a type of LLM memory, but there can be other types
+struct llama_memory_i {
+    virtual ~llama_memory_i() = default;
+
+    // split the input batch into a set of ubatches and verify that they can fit into the cache
+    // return a state object containing the ubatches and KV cache state required to process them
+    // check the llama_memory_state_i::get_status() for the result
+    virtual llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) = 0;
+
+    // simulate full cache, used for allocating worst-case compute buffers
+    virtual llama_memory_state_ptr init_full() = 0;
+
+    // prepare for any pending memory updates, such as shifts, defrags, etc.
+    // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
+
+    // getters
+    virtual bool get_can_shift() const = 0;
+
+    //
+    // ops
+    //
+
+    // if data == true, the data buffers will also be cleared together with the metadata
+    virtual void clear(bool data) = 0;
+
+    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
+    virtual void seq_keep(llama_seq_id seq_id) = 0;
+    virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
+    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
+
+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
+    virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
+
+    //
+    // state write/read
+    //
+
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
+};
+
+using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
+
+// TODO: temporary until the llama_kv_cache is removed from the public API
+struct llama_kv_cache : public llama_memory_i {
+    virtual ~llama_kv_cache() = default;
+};
index 9da97f1bc5057d830943307a57120e497176bb56..47497cf953fd3990d6e147e4837980bed6223b63 100644 (file)
@@ -401,7 +401,7 @@ struct llama_mmap::impl {
                 }
             }
 #else
-            throw std::runtime_error("PrefetchVirtualMemory unavailable");
+            LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n");
 #endif
         }
     }
index ddb1b03675b289acdf53f66c4e5cb1c2ad80b589..bd9e6da8832b78c7d5a1f4661ef84c33269bea10 100644 (file)
@@ -288,9 +288,10 @@ namespace GGUFMeta {
 
     template<typename T>
     bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
@@ -298,28 +299,40 @@ namespace GGUFMeta {
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
-            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                                (std::is_same<T, uint32_t>::value)); break;
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,     int32_t>::value) ||
+                                                (std::is_same<T,    uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,       float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+            result.clear();
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result.emplace_back(value);
+            }
+        } else {
+            result.resize(arr_info.length);
+            result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
+        }
 
         return true;
     }
 
     template<typename T, size_t N_MAX>
     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) {
-        const int kid = gguf_find_key(meta.get(), key.c_str());
+        const gguf_context * ctx = meta.get();
+        const int kid = gguf_find_key(ctx, key.c_str());
 
-        if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) {
+        if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) {
             if (required) {
                 throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
             }
@@ -327,22 +340,32 @@ namespace GGUFMeta {
         }
 
         struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid);
 
         switch (arr_info.gt) {
             case GGUF_TYPE_UINT32:
-            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
-                                                (std::is_same<T, uint32_t>::value)); break;
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break;
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,     int32_t>::value) ||
+                                                (std::is_same<T,    uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,       float>::value)); break;
+            case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str()));
         }
 
         if (arr_info.length > N_MAX) {
             throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
         }
 
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        if constexpr (std::is_same<T, std::string>::value) {
+            const size_t n_items = gguf_get_arr_n(ctx, kid);
+
+            for (size_t i = 0; i < n_items; i++) {
+                const T value = gguf_get_arr_str(ctx, kid, i);
+                result[i] = value;
+            }
+        } else {
+            std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
+        }
 
         return true;
     }
@@ -352,6 +375,8 @@ namespace GGUFMeta {
         return get_arr(llm_kv(kid), result, required);
     }
 
+    template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
+
     template<typename T>
     bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
         auto it = kv_overrides.find(key);
index 50264a69aac0ebd882c54009d893a3402bbfd069..c41ee24507fca47f7cab1ca353c4edfd4ec7bd9a 100644 (file)
@@ -543,6 +543,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     uint32_t n_vocab = 0;
     ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
 
+    // for classifier models
+    ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
+    if (!classifier_labels.empty()) {
+        hparams.n_cls_out = classifier_labels.size();
+    }
+
     // arch-specific KVs
     switch (arch) {
         case LLM_ARCH_LLAMA:
@@ -686,7 +692,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
-                ml.get_arr_n(LLM_KV_CLASSIFIER_OUTPUT_LABELS, hparams.n_cls_out, false);
 
                 switch (hparams.n_layer) {
                     case 3:
@@ -956,6 +961,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     case 46: type = LLM_TYPE_27B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                }
+
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173
+                hparams.f_attention_scale = type == LLM_TYPE_27B
+                    ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
+                    : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
         case LLM_ARCH_GEMMA3:
             {
@@ -976,6 +986,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
 
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289
                 hparams.f_attention_scale = type == LLM_TYPE_27B
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
@@ -4356,6 +4367,15 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
         LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
+
+        if (!classifier_labels.empty()) {
+            LLAMA_LOG_INFO("%s: n_cls_out        = %u\n", __func__, hparams.n_cls_out);
+
+            size_t i = 0;
+            for (auto label : classifier_labels) {
+                LLAMA_LOG_INFO("%s: cls_label[%2zu]    = %s\n", __func__, i++, label.c_str());
+            }
+        }
     }
 
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
@@ -8484,14 +8504,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
-                switch (model.type) {
-                    case LLM_TYPE_2B:
-                    case LLM_TYPE_9B:
-                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break;
-                    default: GGML_ABORT("fatal error");
-                };
-                cb(Qcur, "Qcur_scaled", il);
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
 
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
@@ -8632,9 +8645,12 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
+                // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315
+                Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale);
+
                 cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
-                        Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
             }
 
             cur = build_norm(cur,
@@ -13600,6 +13616,18 @@ int32_t llama_model_n_swa(const llama_model * model) {
     return model->hparams.n_swa;
 }
 
+uint32_t llama_model_n_cls_out(const struct llama_model * model) {
+    return model->hparams.n_cls_out;
+}
+
+const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) {
+    if (i < model->classifier_labels.size()) {
+        return model->classifier_labels[i].c_str();
+    }
+
+    return nullptr;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);
@@ -13760,7 +13788,7 @@ uint64_t llama_model_size(const llama_model * model) {
 }
 
 const char * llama_model_chat_template(const llama_model * model, const char * name) {
-    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE)
         : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
index cbea2cb331b626f6ca2f829a186ec0822b20ce76..18b714620bbcf899e23828eb9653fcb80b728e49 100644 (file)
@@ -329,6 +329,9 @@ struct llama_model {
     llama_hparams hparams = {};
     llama_vocab   vocab;
 
+    // for classifier models
+    std::vector<std::string> classifier_labels;
+
     struct ggml_tensor * tok_embd   = nullptr;
     struct ggml_tensor * type_embd  = nullptr;
     struct ggml_tensor * pos_embd   = nullptr;
index d5a036a8c4413cb91c6d242ec2f049ced6b3cc62..ba2e1864ec0050ba58777b417ada98113809e3cd 100644 (file)
@@ -2080,9 +2080,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         std::string model_name;
         std::string tokenizer_pre;
+        std::string general_arch;
 
         ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
         ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+        ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
 
         // model name to lowercase
         std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@@ -2091,9 +2093,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             }
         );
 
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        // set attributes by model/tokenizer/architecture name
+        if (false
+                || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+                || _contains_any(general_arch, {"nomic-bert-moe"})
+           ) {
+            if (token_to_id.count("<mask>") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+            } else {
+                _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : cache_special_tokens) {
                 _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
index da0f652cfd63a409014cae7bff57956611f9e817..015a57898e22d2411296108fbc7446713f1dc2c9 100644 (file)
@@ -61,7 +61,10 @@ extern "C" {
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
-    struct llama_kv_cache;
+
+    typedef struct llama_memory_i * llama_memory_t;
+
+    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -493,9 +496,11 @@ extern "C" {
     DEPRECATED(LLAMA_API int32_t llama_n_vocab    (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
 
     LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
-    LLAMA_API    struct llama_kv_cache * llama_get_kv_self (      struct llama_context * ctx);
+    LLAMA_API           llama_memory_t   llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
+    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
+
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 
@@ -509,6 +514,13 @@ extern "C" {
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
 
+    // Returns the number of classifier outputs (only valid for classifier models)
+    // Undefined behavior for non-classifier models
+    LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model);
+
+    // Returns label of classifier output by index (<n_cls_out). Returns nullptr if no label provided
+    LLAMA_API const char * llama_model_cls_label(const struct llama_model * model, uint32_t i);
+
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
 
     LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
@@ -609,7 +621,81 @@ extern "C" {
                          int32_t   il_end);
 
     //
-    // KV cache
+    // Memory
+    //
+
+    // Clear the memory contents
+    // If data == true, the data buffers will also be cleared together with the metadata
+    LLAMA_API void llama_memory_clear(
+            llama_memory_t mem,
+                      bool data);
+
+    // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
+    // seq_id < 0 : match any sequence
+    // p0 < 0     : [0,  p1]
+    // p1 < 0     : [p0, inf)
+    LLAMA_API bool llama_memory_seq_rm(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Copy all tokens that belong to the specified sequence to another sequence
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_cp(
+            llama_memory_t mem,
+              llama_seq_id seq_id_src,
+              llama_seq_id seq_id_dst,
+                 llama_pos p0,
+                 llama_pos p1);
+
+    // Removes all tokens that do not belong to the specified sequence
+    LLAMA_API void llama_memory_seq_keep(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_add(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                 llama_pos delta);
+
+    // Integer division of the positions by factor of `d > 1`
+    // p0 < 0 : [0,  p1]
+    // p1 < 0 : [p0, inf)
+    LLAMA_API void llama_memory_seq_div(
+            llama_memory_t mem,
+              llama_seq_id seq_id,
+                 llama_pos p0,
+                 llama_pos p1,
+                       int d);
+
+    // Returns the smallest position present in the memory for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_min(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Returns the largest position present in the memory for the specified sequence
+    // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_memory_seq_pos_max(
+            llama_memory_t mem,
+              llama_seq_id seq_id);
+
+    // Check if the memory supports shifting
+    LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
+
+    //
+    // KV cache for self-attention (TODO: deprecate in favor of llama_memory)
     //
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
@@ -622,86 +708,95 @@ extern "C" {
                "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
-    LLAMA_API void llama_kv_self_clear(
-            struct llama_context * ctx);
+    DEPRECATED(LLAMA_API void llama_kv_self_clear(
+                struct llama_context * ctx),
+            "Use llama_memory_clear() instead");
 
     // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
     // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
     // seq_id < 0 : match any sequence
     // p0 < 0     : [0,  p1]
     // p1 < 0     : [p0, inf)
-    LLAMA_API bool llama_kv_self_seq_rm(
+    DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
-                       llama_pos   p1);
+                       llama_pos   p1),
+            "Use llama_memory_seq_rm() instead");
 
     // Copy all tokens that belong to the specified sequence to another sequence
     // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_self_seq_cp(
+    DEPRECATED(LLAMA_API void llama_kv_self_seq_cp(
             struct llama_context * ctx,
                     llama_seq_id   seq_id_src,
                     llama_seq_id   seq_id_dst,
                        llama_pos   p0,
-                       llama_pos   p1);
+                       llama_pos   p1),
+            "Use llama_memory_seq_cp() instead");
 
     // Removes all tokens that do not belong to the specified sequence
-    LLAMA_API void llama_kv_self_seq_keep(
+    DEPRECATED(LLAMA_API void llama_kv_self_seq_keep(
             struct llama_context * ctx,
-                    llama_seq_id   seq_id);
+                    llama_seq_id   seq_id),
+            "Use llama_memory_seq_keep() instead");
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_self_seq_add(
+    DEPRECATED(LLAMA_API void llama_kv_self_seq_add(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
                        llama_pos   p1,
-                       llama_pos   delta);
+                       llama_pos   delta),
+            "Use llama_memory_seq_add() instead");
 
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_self_seq_div(
+    DEPRECATED(void llama_kv_self_seq_div(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
                        llama_pos   p1,
-                             int   d);
+                             int   d),
+            "Use llama_memory_seq_div() instead");
 
     // Returns the smallest position present in the KV cache for the specified sequence
     // This is typically non-zero only for SWA caches
     // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
-    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+    DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min(
             struct llama_context * ctx,
-                    llama_seq_id   seq_id);
+                    llama_seq_id   seq_id),
+            "Use llama_memory_seq_pos_min() instead");
 
     // Returns the largest position present in the KV cache for the specified sequence
     // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
     // Return -1 if the sequence is empty
-    LLAMA_API llama_pos llama_kv_self_seq_pos_max(
+    DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-                    llama_seq_id   seq_id);
+                    llama_seq_id   seq_id),
+            "Use llama_memory_seq_pos_max() instead");
 
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
             "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
 
     // Check if the context supports KV cache shifting
-    LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx),
+            "use llama_memory_can_shift() instead");
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+    DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
             "simply remove this call, updates are applied lazily on the next llama_decode()");
 
     //
@@ -709,7 +804,7 @@ extern "C" {
     //
 
     // Returns the *actual* size in bytes of the state
-    // (logits, embedding and kv_cache)
+    // (logits, embedding and memory)
     // Only use when saving the state, not when restoring it, otherwise the size may be too small.
     LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
     LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@@ -765,12 +860,12 @@ extern "C" {
                           size_t   n_token_count),
         "use llama_state_save_file instead");
 
-    // Get the exact size needed to copy the KV cache of a single sequence
+    // Get the exact size needed to copy the state of a single sequence
     LLAMA_API size_t llama_state_seq_get_size(
             struct llama_context * ctx,
                     llama_seq_id   seq_id);
 
-    // Copy the KV cache of a single sequence into the specified buffer
+    // Copy the state of a single sequence into the specified buffer
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
@@ -836,16 +931,16 @@ extern "C" {
     // For encode-decoder contexts, processes the batch using the encoder.
     // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // < 0 - error. the memory state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
 
     // Process a batch of tokens.
-    // Requires KV cache.
+    // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
     //    2 - aborted
@@ -916,7 +1011,7 @@ extern "C" {
 
     // Get the embeddings for a sequence id
     // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
-    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence
+    // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);