]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Tue, 13 May 2025 10:20:19 +0000 (13:20 +0300)
committerGeorgi Gerganov <redacted>
Tue, 13 May 2025 10:59:21 +0000 (13:59 +0300)
ggml-ci

25 files changed:
examples/talk-llama/CMakeLists.txt
examples/talk-llama/llama-adapter.cpp
examples/talk-llama/llama-batch.cpp
examples/talk-llama/llama-batch.h
examples/talk-llama/llama-chat.cpp
examples/talk-llama/llama-chat.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-cparams.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-kv-cache.cpp
examples/talk-llama/llama-kv-cache.h
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-model-loader.cpp
examples/talk-llama/llama-model-saver.cpp [new file with mode: 0644]
examples/talk-llama/llama-model-saver.h [new file with mode: 0644]
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-quant.cpp
examples/talk-llama/llama-sampling.cpp
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama-vocab.h
examples/talk-llama/llama.cpp
examples/talk-llama/llama.h

index 3e3971a7ae1e3a2c292358c601af581268deb029..e060ba7bfc896cbd9d25b88c2b4d5ef804c299e4 100644 (file)
@@ -20,6 +20,7 @@ if (WHISPER_SDL2)
         llama-memory.cpp
         llama-mmap.cpp
         llama-model-loader.cpp
+        llama-model-saver.cpp
         llama-model.cpp
         llama-quant.cpp
         llama-sampling.cpp
index 7ac54d2391fd0c0ef75d623f523eb00e4318a902..8d94034aed95debd4b1ea9269996578744ba809b 100644 (file)
@@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
     std::vector<ggml_backend_buffer_type_t> buft_extra;
     {
         auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        if (!cpu_dev) {
+            throw std::runtime_error(format("%s: no CPU backend found", __func__));
+        }
         auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
 
         auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
@@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
                 LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
 
                 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                if (!cpu_dev) {
+                    throw std::runtime_error(format("%s: no CPU backend found", __func__));
+                }
                 buft = ggml_backend_dev_buffer_type(cpu_dev);
 
                 break;
index 01d5ca57fd82bcbba57247774f2190dbccdb4cf6..a88b2fe3082c9447c1ed8cf66c90fb802855a959 100644 (file)
@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
@@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
     for (size_t i = 0; i < n_tokens; ++i) {
         ids[i] = i;
     }
+
     if (simple_split) {
         seq.resize(1);
         llama_sbatch_seq & s = seq[0];
@@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         s.length = n_tokens;
         return;
     }
+
     std::sort(ids.begin(), ids.end(),
             [&batch](size_t a, size_t b) {
                 int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
                 return n_seq_a > n_seq_b;
             }
     );
+
     // init seq
     llama_sbatch_seq * last_seq = nullptr;
 
@@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
         seq.push_back(new_seq);
         last_seq = &seq.back();
     }
+
     // keep shared prompts first at the end, then sort by length descending.
     std::sort(seq.begin(), seq.end(),
             [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
index f1df40d27086e7a47c56dd43bfabfdaf3fa0688b..6305051b62b794499a774f522169414d179e6989 100644 (file)
@@ -70,7 +70,8 @@ struct llama_sbatch {
     // sequence-wise split
     llama_ubatch split_seq(size_t n_ubatch);
 
-    void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch() = default;
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
 };
 
 // temporary allocate memory for the input batch if needed
index 735d2619c928fcfb7ba415ded0663318dba8eae1..d12743e6b9a0cf1bc12ebd698b71cef0a92db62b 100644 (file)
@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "mistral-v3",        LLM_CHAT_TEMPLATE_MISTRAL_V3        },
     { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
     { "mistral-v7",        LLM_CHAT_TEMPLATE_MISTRAL_V7        },
+    { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
     { "phi3",              LLM_CHAT_TEMPLATE_PHI_3             },
     { "phi4",              LLM_CHAT_TEMPLATE_PHI_4             },
     { "falcon3",           LLM_CHAT_TEMPLATE_FALCON_3          },
@@ -202,19 +203,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_start|>assistant\n";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) {
         // Official mistral 'v7' template
         // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
+        //      https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
+        const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
         for (auto message : chat) {
             std::string role(message->role);
             std::string content(message->content);
             if (role == "system") {
-                ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
+                ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]";
             } else if (role == "user") {
-                ss << "[INST] " << content << "[/INST]";
-            }
-            else {
-                ss << " " << content << "</s>";
+                ss << "[INST]" << trailing_space << content << "[/INST]";
+            } else {
+                ss << trailing_space << content << "</s>";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
@@ -447,8 +449,16 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|assistant|>";
         }
-    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) {
         ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
         for (auto message : chat) {
             std::string role(message->role);
             ss << "<|" << role << "|>" << "\n" << message->content;
index 3f5843466d044c4ceb7c95d9f7a4b81b8820f587..db24ade21e2ad76671732d3c6b625ea76357120e 100644 (file)
@@ -14,6 +14,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_MISTRAL_V3,
     LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
     LLM_CHAT_TEMPLATE_MISTRAL_V7,
+    LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
     LLM_CHAT_TEMPLATE_PHI_3,
     LLM_CHAT_TEMPLATE_PHI_4,
     LLM_CHAT_TEMPLATE_FALCON_3,
index 5a2eef9b784a12a71d49707a26c02b8f6c1f5b46..62246c10dab089f4583f0174492792eb27e7c543 100644 (file)
@@ -6,11 +6,9 @@
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 
-#include <cassert>
 #include <cstring>
 #include <stdexcept>
 #include <cinttypes>
-#include <cmath>
 
 //
 // llama_context
@@ -95,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+    cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -118,8 +117,6 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
-    logits_all = params.logits_all;
-
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {
@@ -177,44 +174,13 @@ llama_context::llama_context(
     }
 
     // init the memory module
-    // TODO: for now, always create a unified KV cache
     if (!hparams.vocab_only) {
-        kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
-
-        LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
-
-        cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
-
-        LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-        uint32_t kv_size = cparams.n_ctx;
-        ggml_type type_k = params.type_k;
-        ggml_type type_v = params.type_v;
-
-        if (llama_model_is_recurrent(&model)) {
-            // Mamba needs at least as many KV cells as there are sequences kept at any time
-            kv_size = std::max((uint32_t) 1, params.n_seq_max);
-            // it's probably best to keep as much precision as possible for the states
-            type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-            type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-        }
-
-        GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
-        GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
-        if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
-            throw std::runtime_error("failed to initialize self-attention cache");
-        }
+        llama_memory_params params_mem = {
+            /*.type_k =*/ params.type_k,
+            /*.type_v =*/ params.type_v,
+        };
 
-        {
-            const size_t memory_size_k = kv_self->size_k_bytes();
-            const size_t memory_size_v = kv_self->size_v_bytes();
-
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                    (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                    ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                    ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
+        memory.reset(model.create_memory(params_mem, cparams));
     }
 
     // init backends
@@ -278,7 +244,7 @@ llama_context::llama_context(
             }
         }
 
-        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
 
         if (pipeline_parallel) {
             LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
@@ -286,7 +252,7 @@ llama_context::llama_context(
     }
 
     // reserve worst-case graph
-    if (!hparams.vocab_only) {
+    if (!hparams.vocab_only && memory) {
         const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -305,7 +271,9 @@ llama_context::llama_context(
         int n_nodes_tg  = -1;
 
         // simulate full KV cache
-        kv_self->n = kv_self->size;
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->set_full();
 
         cross.v_embd.clear();
 
@@ -391,7 +359,9 @@ llama_context::llama_context(
     }
 }
 
-llama_context::~llama_context() = default;
+llama_context::~llama_context() {
+    ggml_opt_free(opt_ctx);
+}
 
 void llama_context::synchronize() {
     ggml_backend_sched_synchronize(sched.get());
@@ -427,6 +397,18 @@ const llama_model & llama_context::get_model() const {
     return model;
 }
 
+const llama_cparams & llama_context::get_cparams() const {
+    return cparams;
+}
+
+ggml_backend_sched_t llama_context::get_sched() const {
+    return sched.get();
+}
+
+ggml_context * llama_context::get_ctx_compute() const {
+    return ctx_compute.get();
+}
+
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -456,337 +438,21 @@ uint32_t llama_context::n_threads_batch() const {
 }
 
 llama_kv_cache * llama_context::get_kv_self() {
-    return kv_self.get();
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }
 
 const llama_kv_cache * llama_context::get_kv_self() const {
-    return kv_self.get();
-}
-
-ggml_tensor * llama_context::build_rope_shift(
-        ggml_context * ctx0,
-        ggml_tensor * cur,
-        ggml_tensor * shift,
-        ggml_tensor * factors,
-              float   freq_base,
-              float   freq_scale) const {
-    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
-
-    const auto & yarn_ext_factor  = cparams.yarn_ext_factor;
-    const auto & yarn_beta_fast   = cparams.yarn_beta_fast;
-    const auto & yarn_beta_slow   = cparams.yarn_beta_slow;
-
-    const auto & hparams = model.hparams;
-
-    const auto & n_rot     = hparams.n_rot;
-    const auto & rope_type = hparams.rope_type;
-
-    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
-
-    ggml_tensor * tmp;
-
-    if (ggml_is_quantized(cur->type)) {
-        // dequantize to f32 -> RoPE -> quantize back
-        tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
-
-        tmp = ggml_rope_ext(ctx0, tmp,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-
-        tmp = ggml_cpy(ctx0, tmp, cur);
-    } else {
-        // we rotate only the first n_rot dimensions
-        tmp = ggml_rope_ext_inplace(ctx0, cur,
-                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
-    }
-
-    return tmp;
-}
-
-class llm_graph_input_k_shift : public llm_graph_input_i {
-public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
-    virtual ~llm_graph_input_k_shift() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * k_shift; // I32 [kv_size]
-
-    const llama_kv_cache_unified * kv_self;
-};
-
-void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    if (k_shift) {
-        assert(ggml_backend_buffer_is_host(k_shift->buffer));
-
-        int32_t * data = (int32_t *) k_shift->data;
-
-        for (uint32_t i = 0; i < kv_self->size; ++i) {
-            data[i] = kv_self->cells[i].delta;
-        }
-    }
-}
-
-llm_graph_result_ptr llama_context::build_kv_self_shift(
-        ggml_context * ctx0,
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
-
-    const auto & hparams = model.hparams;
-
-    const auto & n_layer = hparams.n_layer;
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-  //const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    //GGML_ASSERT(kv_self->size == n_ctx);
-
-    auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
-
-    inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
-    ggml_set_input(inp->k_shift);
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const int64_t n_head_kv    = hparams.n_head_kv(il);
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-
-        const bool is_swa = hparams.is_swa(il);
-
-        // note: the swa rope params could become part of the cparams in the future
-        //       if we decide to make them configurable, like the non-sliding ones
-        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
-        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
-
-        ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
-
-        ggml_tensor * k =
-            ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_head_kv, kv_self->size,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                0);
-
-        ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
-
-        ggml_build_forward_expand(gf, cur);
-    }
-
-    res->add_input(std::move(inp));
-
-    return res;
-}
-
-llm_graph_result_ptr llama_context::build_kv_self_defrag(
-        ggml_context * ctx0,
-        ggml_cgraph * gf) const {
-    auto res = std::make_unique<llm_graph_result>();
-
-    const auto & hparams = model.hparams;
-
-    const auto & ids = kv_self->defrag_info.ids;
-
-#if 0
-    // CPU defrag
-    //
-    // TODO: optimizations are possible:
-    //       - multiple threads
-    //       - avoid copying to the host memory when already there
-    //
-    // likely not worth the effort, as we have ggml_graph based defrag
-    //
-
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-
-    const uint32_t kv_size = size;
-
-    std::vector<uint8_t> buf_k;
-    std::vector<uint8_t> buf_v;
-
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
-
-        const size_t v_size_el = ggml_type_size(v_l[il]->type);
-        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
-
-        buf_k.resize(k_size);
-        buf_v.resize(v_size);
-
-        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
-
-        // batch move [i, i+nm) to [id, id+nm)
-        // note: cells can move only to a lower index
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t id = ids[i];
-
-            if (i == id || id == n_kv) {
-                continue;
-            }
-
-            uint32_t nm = 1;
-
-            while (i + nm < n_kv && ids[i + nm] == id + nm) {
-                nm++;
-            }
-
-            // move keys
-            {
-                const int64_t os =  i*k_size_row;
-                const int64_t od = id*k_size_row;
-
-                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
-            }
-
-            // move values (note: they are transposed)
-            {
-                const int64_t os =  i;
-                const int64_t od = id;
-
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
-                }
-            }
-
-            i += nm - 1;
-        }
-
-        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
-        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
-    }
-#else
-    for (uint32_t i = 0; i < ids.size(); ++i) {
-        const uint32_t id = ids[i];
-
-        if (i == id || id == ids.size()) {
-            continue;
-        }
-
-        uint32_t nm = 1;
-
-        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
-            nm++;
-        }
-
-        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
-            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-            ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
-
-            ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
-                    n_embd_k_gqa, nm,
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
-
-            ggml_tensor * view_v_src;
-            ggml_tensor * view_v_dst;
-
-            if (cparams.flash_attn) {
-                // NOTE: the V cache is not transposed when using flash attention
-                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
-
-                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
-                        n_embd_v_gqa, nm,
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                        ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
-            } else {
-                view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                        ggml_row_size(kv_self->v_l[il]->type, i));
-
-                view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
-                        nm, n_embd_v_gqa,
-                        ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
-                        ggml_row_size(kv_self->v_l[il]->type, id));
-            }
-
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
-            ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
-        }
-
-        i += nm - 1;
-    }
-
-    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-#endif
-
-    return res;
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    return kv_self;
 }
 
 void llama_context::kv_self_update() {
-    auto & kv = kv_self;
-
     bool need_reserve = false;
 
-    if (kv->has_shift) {
-        if (!kv->get_can_shift()) {
-            GGML_ABORT("The current context does not support K-shift");
-        }
-
-        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
-
-        // apply K-shift if needed
-        if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            auto res = build_kv_self_shift(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            res->set_inputs(nullptr);
-
-            graph_compute(gf, false);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
-            need_reserve = true;
-        }
-
-        {
-            kv->has_shift = false;
-
-            for (uint32_t i = 0; i < kv->size; ++i) {
-                kv->cells[i].delta = 0;
-            }
-        }
-    }
-
-    // defragment the KV cache if needed
-    if (kv->do_defrag) {
-        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
-
-        if (kv->defrag_prepare(graph_max_nodes())) {
-            ggml_backend_sched_reset(sched.get());
-
-            auto * gf = graph_init();
-
-            auto res = build_kv_self_defrag(ctx_compute.get(), gf);
-
-            ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-            res->set_inputs(nullptr);
-
-            graph_compute(gf, false);
-
-            need_reserve = true;
-        }
-
-        kv->do_defrag = false;
-    }
+    need_reserve = kv_self->update(*this);
 
     // reserve a worst case graph if needed
     if (need_reserve) {
@@ -797,7 +463,7 @@ void llama_context::kv_self_update() {
         uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         // simulate full KV cache
-        kv_self->n = kv_self->size;
+        kv_self->set_full();
 
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@@ -818,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
-    // reorder logits for backward compatibility
-    output_reorder();
-
     return logits;
 }
 
@@ -863,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
-    // reorder embeddings for backward compatibility
-    output_reorder();
-
     return embd;
 }
 
@@ -1017,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) {
     }
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // note: during encode, we always pass the full sequence starting from pos = 0
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
 
     const llama_batch & batch = batch_allocr.batch;
     const int32_t n_tokens = batch.n_tokens;
@@ -1043,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }
 
+    embd_seq.clear();
+
     n_queued_tokens += n_tokens;
 
     const int64_t n_embd = hparams.n_embd;
 
-    sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -1104,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
-        GGML_ASSERT(embd != nullptr);
-
         switch (cparams.pooling_type) {
             case LLAMA_POOLING_TYPE_NONE:
                 {
                     // extract token embeddings
+                    GGML_ASSERT(embd != nullptr);
+
                     GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
                 } break;
@@ -1134,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) {
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
-                    //       wait for an encoder model that requires this pooling type in order to test it
-                    //       https://github.com/ggerganov/llama.cpp/pull/9510
-                    GGML_ABORT("RANK pooling not implemented yet");
-                }
+                    // extract the rerank score - a single float per sequence
+                    auto & embd_seq_out = embd_seq;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                        if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                            continue;
+                        }
+                        embd_seq_out[seq_id].resize(1);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                    }
+                } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
                 {
                     GGML_ABORT("unknown pooling type");
@@ -1176,14 +845,21 @@ int llama_context::encode(llama_batch & inp_batch) {
 }
 
 int llama_context::decode(llama_batch & inp_batch) {
+    if (!memory) {
+        LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
+        return encode(inp_batch);
+    }
+
     if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
+    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -1195,7 +871,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd       = hparams.n_embd;
 
-    llama_kv_cache_guard kv_guard(kv_self.get());
+    llama_kv_cache_guard kv_guard(kv_self);
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
@@ -1229,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs_all += batch.logits[i] != 0;
         }
-    } else if (logits_all || embd_pooled) {
+    } else if (embd_pooled) {
         n_outputs_all = n_tokens_all;
     } else {
         // keep last output only
         n_outputs_all = 1;
     }
 
-    const bool logits_all = n_outputs_all == n_tokens_all;
-
-    sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self->recurrent,
-            /* logits_all   */ logits_all);
+    llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
 
     // reserve output buffer
     if (output_reserve(n_outputs_all) < n_outputs_all) {
@@ -1254,22 +926,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     int64_t n_outputs_prev = 0;
 
     while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch = llama_ubatch();
-
-        const auto & n_ubatch = cparams.n_ubatch;
-
-        if (kv_self->recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = sbatch.split_seq(cparams.n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = sbatch.split_equal(cparams.n_ubatch);
-            }
-        } else {
-            ubatch = sbatch.split_simple(n_ubatch);
-        }
+        llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
 
         // count the outputs in this u_batch
         {
@@ -1289,24 +946,12 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
 
         // find KV slot
-        {
-            if (!kv_self->find_slot(ubatch)) {
-                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
-                return 1;
-            }
+        if (!kv_self->find_slot(ubatch)) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
 
-            if (!kv_self->recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self->get_padding(cparams);
-                kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
-            }
+            return 1;
         }
 
-        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
-
         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -1420,43 +1065,68 @@ int llama_context::decode(llama_batch & inp_batch) {
     // finalize the batch processing
     kv_guard.commit();
 
+    // set to total number of outputs in the batch, for use in llama_get_logits_ith
+    n_outputs = n_outputs_all;
+
     // set output mappings
     {
         bool sorted_output = true;
 
-        GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
+        auto & out_ids = sbatch.out_ids;
+
+        GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
 
         for (int64_t i = 0; i < n_outputs_all; ++i) {
-            int64_t out_id = sbatch.out_ids[i];
+            int64_t out_id = out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
             }
         }
 
-        if (sorted_output) {
-            sbatch.out_ids.clear();
+        // make the outputs have the same order they had in the user-provided batch
+        // note: this is mostly relevant for recurrent models atm
+        if (!sorted_output) {
+            const uint32_t n_vocab = model.vocab.n_tokens();
+            const uint32_t n_embd  = model.hparams.n_embd;
+
+            GGML_ASSERT((size_t) n_outputs == out_ids.size());
+
+            // TODO: is there something more efficient which also minimizes swaps?
+            // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
+            for (int32_t i = 0; i < n_outputs - 1; ++i) {
+                int32_t j_min = i;
+                for (int32_t j = i + 1; j < n_outputs; ++j) {
+                    if (out_ids[j] < out_ids[j_min]) {
+                        j_min = j;
+                    }
+                }
+                if (j_min == i) { continue; }
+                std::swap(out_ids[i], out_ids[j_min]);
+                if (logits_size > 0) {
+                    for (uint32_t k = 0; k < n_vocab; k++) {
+                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
+                    }
+                }
+                if (embd_size > 0) {
+                    for (uint32_t k = 0; k < n_embd; k++) {
+                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
+                    }
+                }
+            }
+            std::fill(output_ids.begin(), output_ids.end(), -1);
+            for (int32_t i = 0; i < n_outputs; ++i) {
+                output_ids[out_ids[i]] = i;
+            }
         }
     }
 
-    // set to total number of outputs in the batch, for use in llama_get_logits_ith
-    n_outputs = n_outputs_all;
-
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-        // - do not defrag small contexts (i.e. < 2048 tokens)
-        // - count the padding towards the number of used tokens
-        const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
-
-        // queue defragmentation for next llama_kv_cache_update
-        if (fragmentation > cparams.defrag_thold) {
-            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-            kv_self->defrag();
-        }
+    if (cparams.defrag_thold > 0.0f) {
+        kv_self->defrag_sched(cparams.defrag_thold);
     }
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -1542,44 +1212,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
-void llama_context::output_reorder() {
-    auto & out_ids = sbatch.out_ids;
-    if (!out_ids.empty()) {
-        const uint32_t n_vocab = model.vocab.n_tokens();
-        const uint32_t n_embd  = model.hparams.n_embd;
-
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                }
-            }
-            if (embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(output_ids.begin(), output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
 //
 // graph
 //
@@ -1616,7 +1248,7 @@ llm_graph_result_ptr llama_context::graph_build(
                 /*.backend_cpu =*/ backend_cpu,
                 /*.cvec        =*/ &cvec,
                 /*.loras       =*/ &loras,
-                /*.memory      =*/ kv_self.get(),
+                /*.memory      =*/ memory.get(),
                 /*.cross       =*/ &cross,
                 /*.n_outputs   =*/ n_outputs,
                 /*.cb          =*/ graph_get_cb(),
@@ -2020,8 +1652,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
-        output_reorder();
-
         const auto n_outputs    = this->n_outputs;
         const auto & output_ids = this->output_ids;
 
@@ -2075,6 +1705,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     }
 
     LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
     kv_self->state_write(io);
 
     return io.n_bytes();
@@ -2158,8 +1790,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
     }
 
-    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
-    kv_self->state_read(io);
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io);
+    }
 
     return io.n_bytes();
 }
@@ -2167,7 +1804,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
-    kv_self->state_write(io, seq_id);
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_write(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -2175,7 +1816,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
-    kv_self->state_read(io, seq_id);
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -2203,6 +1848,215 @@ void llama_context::perf_reset() {
     t_p_eval_us = n_p_eval = 0;
 }
 
+//
+// training
+//
+
+static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    if (strcmp(tensor->name, "token_embd.weight") == 0) {
+        return; // FIXME
+    }
+    if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+        return; // FIXME
+    }
+    ggml_set_param(tensor);
+}
+
+void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+    GGML_ASSERT(!opt_ctx);
+    model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+    const uint32_t n_batch     = std::min(this->n_batch(),  model->hparams.n_ctx_train);
+    const uint32_t n_ubatch    = std::min(this->n_ubatch(), n_batch);
+    GGML_ASSERT(model->hparams.n_ctx_train % n_batch  == 0);
+    GGML_ASSERT(n_batch                    % n_ubatch == 0);
+
+    ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+    opt_params.opt_period      = n_batch / n_ubatch;
+    opt_params.get_opt_pars    = lopt_params.get_opt_pars;
+    opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+    opt_ctx = ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud              = lopt_params.param_filter_ud;
+
+  //llama_set_param(model->tok_embd,        param_filter, param_filter_ud); // FIXME
+    llama_set_param(model->type_embd,       param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd,        param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm,        param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b,      param_filter, param_filter_ud);
+    llama_set_param(model->output_norm,     param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b,   param_filter, param_filter_ud);
+    llama_set_param(model->output,          param_filter, param_filter_ud);
+    llama_set_param(model->output_b,        param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls,             param_filter, param_filter_ud);
+    llama_set_param(model->cls_b,           param_filter, param_filter_ud);
+    llama_set_param(model->cls_out,         param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b,       param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
+}
+
+void llama_context::opt_epoch_iter(
+        ggml_opt_dataset_t               dataset,
+        ggml_opt_result_t                result,
+        const std::vector<llama_token> & tokens,
+        const std::vector<llama_token> & labels_sparse,
+        llama_batch                    & batch,
+        ggml_opt_epoch_callback          callback,
+        bool                             train,
+        int64_t                          idata_in_loop,
+        int64_t                          ndata_in_loop,
+        int64_t                          t_loop_start) {
+    GGML_ASSERT(opt_ctx);
+    const uint32_t n_ctx    = llama_model_n_ctx_train(&model);
+    const uint32_t n_batch  = std::min(this->n_batch(),  n_ctx);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+    kv_self->clear();
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        batch.n_tokens = n_batch;
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+            batch.token   [pos_batch]    = tokens[pos_ctx + pos_batch];
+            batch.pos     [pos_batch]    = pos_ctx + pos_batch;
+            batch.n_seq_id[pos_batch]    = 1;
+            batch.seq_id  [pos_batch][0] = 0;
+            batch.logits  [pos_batch]    = true;
+        }
+
+        const auto n_tokens_all = batch.n_tokens;
+
+        n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        embd_seq.clear();
+
+        int64_t n_outputs_all = n_tokens_all;
+
+        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+        // reserve output buffer
+        if (output_reserve(n_outputs_all) < n_outputs_all) {
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            GGML_ABORT("TODO: handle this error");
+        };
+
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+            n_outputs = ubatch.n_tokens;
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
+            auto * gf = graph_init();
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+            struct ggml_context * ctx_compute_opt;
+            {
+                const size_t size_gf = ggml_graph_size(gf);
+                const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+                struct ggml_init_params params = {
+                    /*.mem_size   =*/ size_meta,
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
+                ctx_compute_opt = ggml_init(params);
+            }
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+            ggml_opt_alloc(opt_ctx, train);
+            res->set_inputs(&ubatch);
+            {
+                struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+                GGML_ASSERT(labels->ne[1] == n_ubatch);
+                ggml_set_zero(labels);
+                const float onef = 1.0f;
+                for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+                    const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+                    GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+                    ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+                }
+            }
+            ggml_opt_eval(opt_ctx, result);
+            if (callback) {
+                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+            }
+            ggml_free(ctx_compute_opt);
+        }
+    }
+
+    kv_guard.commit();
+}
+
+void llama_context::opt_epoch(
+        ggml_opt_dataset_t        dataset,
+        ggml_opt_result_t         result_train,
+        ggml_opt_result_t         result_eval,
+        int64_t                   idata_split,
+        ggml_opt_epoch_callback   callback_train,
+        ggml_opt_epoch_callback   callback_eval) {
+    const uint32_t n_ctx    = this->n_ctx();
+    const uint32_t n_batch  = std::min(cparams.n_batch,  n_ctx);
+    const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+    const  int64_t ndata    = ggml_opt_dataset_ndata(dataset);
+
+    GGML_ASSERT(idata_split >= 0);
+    GGML_ASSERT(idata_split <= ndata);
+
+    const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    std::vector<llama_token>        tokens(n_ctx);
+    std::vector<llama_token> labels_sparse(n_ctx);
+
+    int64_t idata = 0;
+
+    int64_t t_loop_start = ggml_time_us();
+    int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+    for (; idata < idata_split; ++idata) {
+        constexpr bool train = true;
+        const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+            callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    t_loop_start = ggml_time_us();
+    ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+    for (; idata < ndata; ++idata) {
+        constexpr bool train = false;
+        const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+            callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    llama_batch_free(batch);
+}
+
 //
 // interface implementation
 //
@@ -2230,13 +2084,13 @@ llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data           =*/ nullptr,
         /*.type_k                      =*/ GGML_TYPE_F16,
         /*.type_v                      =*/ GGML_TYPE_F16,
-        /*.logits_all                  =*/ false,
+        /*.abort_callback              =*/ nullptr,
+        /*.abort_callback_data         =*/ nullptr,
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
-        /*.abort_callback              =*/ nullptr,
-        /*.abort_callback_data         =*/ nullptr,
+        /*.op_offload                  =*/ true,
     };
 
     return result;
@@ -2530,7 +2384,7 @@ void llama_kv_cache_seq_cp(
          llama_seq_id   seq_id_dst,
             llama_pos   p0,
             llama_pos   p1) {
-    return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
+    llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_kv_self_seq_cp(
@@ -2544,14 +2398,14 @@ void llama_kv_self_seq_cp(
         return;
     }
 
-    return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
 // deprecated
 void llama_kv_cache_seq_keep(
         llama_context * ctx,
          llama_seq_id   seq_id) {
-    return llama_kv_self_seq_keep(ctx, seq_id);
+    llama_kv_self_seq_keep(ctx, seq_id);
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2560,7 +2414,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
         return;
     }
 
-    return kv->seq_keep(seq_id);
+    kv->seq_keep(seq_id);
 }
 
 // deprecated
@@ -2570,7 +2424,7 @@ void llama_kv_cache_seq_add(
             llama_pos   p0,
             llama_pos   p1,
             llama_pos   delta) {
-    return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
+    llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
 }
 
 void llama_kv_self_seq_add(
@@ -2584,7 +2438,7 @@ void llama_kv_self_seq_add(
         return;
     }
 
-    return kv->seq_add(seq_id, p0, p1, delta);
+    kv->seq_add(seq_id, p0, p1, delta);
 }
 
 // deprecated
@@ -2594,7 +2448,7 @@ void llama_kv_cache_seq_div(
             llama_pos   p0,
             llama_pos   p1,
                   int   d) {
-    return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
+    llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
 }
 
 void llama_kv_self_seq_div(
@@ -2608,7 +2462,7 @@ void llama_kv_self_seq_div(
         return;
     }
 
-    return kv->seq_div(seq_id, p0, p1, d);
+    kv->seq_div(seq_id, p0, p1, d);
 }
 
 // deprecated
@@ -2627,7 +2481,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 
 // deprecated
 void llama_kv_cache_defrag(llama_context * ctx) {
-    return llama_kv_self_defrag(ctx);
+    llama_kv_self_defrag(ctx);
 }
 
 void llama_kv_self_defrag(llama_context * ctx) {
@@ -2636,7 +2490,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
         return;
     }
 
-    return kv->defrag();
+    // force defrag
+    kv->defrag_sched(-1.0f);
 }
 
 // deprecated
@@ -2820,3 +2675,34 @@ void llama_perf_context_print(const llama_context * ctx) {
 void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }
+
+//
+// training
+//
+
+bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(userdata);
+    return true;
+}
+
+void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+    ctx->opt_init(model, lopt_params);
+}
+
+void llama_opt_epoch(
+        struct llama_context    * ctx,
+        ggml_opt_dataset_t        dataset,
+        ggml_opt_result_t         result_train,
+        ggml_opt_result_t         result_eval,
+        int64_t                   idata_split,
+        ggml_opt_epoch_callback   callback_train,
+        ggml_opt_epoch_callback   callback_eval) {
+    ctx->opt_epoch(
+        dataset,
+        result_train,
+        result_eval,
+        idata_split,
+        callback_train,
+        callback_eval);
+}
index 5457f077c15bfedb2e4430044fb977415b1ee4e4..c0ceacb10ce6f56986cfa634a5700b664957a80b 100644 (file)
@@ -7,6 +7,7 @@
 #include "llama-adapter.h"
 
 #include "ggml-cpp.h"
+#include "ggml-opt.h"
 
 #include <map>
 #include <vector>
@@ -27,7 +28,12 @@ struct llama_context {
 
     void synchronize();
 
-    const llama_model & get_model() const;
+    const llama_model   & get_model()   const;
+    const llama_cparams & get_cparams() const;
+
+    ggml_backend_sched_t get_sched() const;
+
+    ggml_context * get_ctx_compute() const;
 
     uint32_t n_ctx()         const;
     uint32_t n_ctx_per_seq() const;
@@ -128,6 +134,32 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    //
+    // training
+    //
+
+    void opt_init(struct llama_model * model, struct llama_opt_params lopt_params);
+
+    void opt_epoch(
+            ggml_opt_dataset_t      dataset,
+            ggml_opt_result_t       result_train,
+            ggml_opt_result_t       result_eval,
+            int64_t                 idata_split,
+            ggml_opt_epoch_callback callback_train,
+            ggml_opt_epoch_callback callback_eval);
+
+    void opt_epoch_iter(
+            ggml_opt_dataset_t               dataset,
+            ggml_opt_result_t                result,
+            const std::vector<llama_token> & tokens,
+            const std::vector<llama_token> & labels_sparse,
+            llama_batch                    & batch,
+            ggml_opt_epoch_callback          callback,
+            bool                             train,
+            int64_t                          idata_in_loop,
+            int64_t                          ndata_in_loop,
+            int64_t                          t_loop_start);
+
 private:
     //
     // output
@@ -137,49 +169,30 @@ private:
     // Returns max number of outputs for which space was reserved.
     int32_t output_reserve(int32_t n_outputs);
 
-    // make the outputs have the same order they had in the user-provided batch
-    // TODO: maybe remove this
-    void output_reorder();
-
     //
     // graph
     //
 
+public:
     int32_t graph_max_nodes() const;
 
     // zero-out inputs and create the ctx_compute for the compute graph
     ggml_cgraph * graph_init();
 
+    // returns the result of ggml_backend_sched_graph_compute_async execution
+    ggml_status graph_compute(
+            ggml_cgraph * gf,
+                   bool   batched);
+
+private:
     llm_graph_result_ptr graph_build(
             ggml_context * ctx,
              ggml_cgraph * gf,
       const llama_ubatch & ubatch,
           llm_graph_type   gtype);
 
-    // returns the result of ggml_backend_sched_graph_compute_async execution
-    ggml_status graph_compute(
-            ggml_cgraph * gf,
-                   bool   batched);
-
     llm_graph_cb graph_get_cb() const;
 
-    // used by kv_self_update()
-    ggml_tensor * build_rope_shift(
-        ggml_context * ctx0,
-        ggml_tensor * cur,
-        ggml_tensor * shift,
-        ggml_tensor * factors,
-              float   freq_base,
-              float   freq_scale) const;
-
-    llm_graph_result_ptr build_kv_self_shift(
-            ggml_context * ctx0,
-            ggml_cgraph * gf) const;
-
-    llm_graph_result_ptr build_kv_self_defrag(
-            ggml_context * ctx0,
-            ggml_cgraph * gf) const;
-
     // TODO: read/write lora adapters and cvec
     size_t state_write_data(llama_io_write_i & io);
     size_t state_read_data (llama_io_read_i  & io);
@@ -196,14 +209,10 @@ private:
     llama_cparams       cparams;
     llama_adapter_cvec  cvec;
     llama_adapter_loras loras;
-    llama_sbatch        sbatch;
 
     llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
 
-    std::unique_ptr<llama_kv_cache_unified> kv_self;
-
-    // TODO: remove
-    bool logits_all = false;
+    std::unique_ptr<llama_memory_i> memory;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
     size_t  logits_size = 0; // capacity (of floats) for logits
@@ -230,6 +239,9 @@ private:
 
     ggml_context_ptr ctx_compute;
 
+    // training
+    ggml_opt_context_t opt_ctx = nullptr;
+
     ggml_threadpool_t threadpool       = nullptr;
     ggml_threadpool_t threadpool_batch = nullptr;
 
index 30e550f023a9e323fae37fd364d9fcb8a3745e41..246fa5777deea1f6d4b94581d9b07b258b9434a2 100644 (file)
@@ -30,6 +30,7 @@ struct llama_cparams {
     bool flash_attn;
     bool no_perf;
     bool warmup;
+    bool op_offload;
 
     enum llama_pooling_type pooling_type;
 
index fabb9ca237653db93b2169a2947e859d391b1759..b0e3f63597a76d0481b77e949f823cda749fd561 100644 (file)
@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-            const uint32_t  cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
         }
     }
 }
@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
 
         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-            const uint32_t  cell_id = i + kv_self->head;
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
         }
     }
 }
@@ -810,7 +782,7 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
     }
 
-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
         cur = ggml_mul(ctx0, cur, tmp);
         cb(cur, "ffn_gate_par", il);
     }
@@ -999,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;
 
         cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
@@ -1105,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -1122,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1255,8 +1228,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
             cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
             cur = ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
+#endif
         }
 
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
@@ -1436,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;
 
         GGML_ASSERT(kv_self->size == n_ctx);
@@ -1587,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
          ggml_tensor * state_mask,
              int32_t   n_state,
              int32_t   n_seqs) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1619,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1640,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
index d0c8d32192784a8296b7e82b4a0394fe6465bf2a..832a8c09f2b80eb816326458d100caa2ca262764 100644 (file)
@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -186,26 +187,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_recurrent * kv_self;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -297,6 +298,7 @@ class llm_graph_result_i {
 public:
     virtual ~llm_graph_result_i() = default;
 
+    virtual ggml_tensor * get_tokens()      = 0;
     virtual ggml_tensor * get_logits()      = 0;
     virtual ggml_tensor * get_embd()        = 0;
     virtual ggml_tensor * get_embd_pooled() = 0;
@@ -311,6 +313,7 @@ class llm_graph_result : public llm_graph_result_i {
 public:
     virtual ~llm_graph_result() = default;
 
+    ggml_tensor * get_tokens()      override { return t_tokens; }
     ggml_tensor * get_logits()      override { return t_logits; }
     ggml_tensor * get_embd()        override { return t_embd; }
     ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
@@ -327,6 +330,7 @@ public:
     }
 
     // important graph nodes
+    ggml_tensor * t_tokens      = nullptr;
     ggml_tensor * t_logits      = nullptr;
     ggml_tensor * t_embd        = nullptr;
     ggml_tensor * t_embd_pooled = nullptr;
@@ -350,8 +354,8 @@ struct llm_graph_params {
     const llama_cparams & cparams;
     const llama_ubatch  & ubatch;
 
-    ggml_backend_sched * sched;
-    ggml_backend * backend_cpu;
+    ggml_backend_sched_t sched;
+    ggml_backend_t backend_cpu;
 
     const llama_adapter_cvec  * cvec;
     const llama_adapter_loras * loras;
@@ -402,9 +406,9 @@ struct llm_graph_context {
 
     ggml_context * ctx0 = nullptr;
 
-    ggml_backend_sched * sched;
+    ggml_backend_sched_t sched;
 
-    ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
+    ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
     const llama_adapter_cvec  * cvec;
     const llama_adapter_loras * loras;
index 7c9d46d8119b39d70bc0c5acb6580d73b7c6dcac..3dcad65bb6a8532fba298648f45119b21fc9b42d 100644 (file)
@@ -4,33 +4,41 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 #include "llama-model.h"
+#include "llama-context.h"
 
 #include <algorithm>
 #include <cassert>
+#include <cmath>
 #include <limits>
 #include <map>
 #include <stdexcept>
 
-llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
+//
+// llama_kv_cache_unified
+//
+
+uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
+    // the FA kernels require padding to avoid extra runtime boundary checks
+    return cparams.flash_attn ? 256u : 32u;
 }
 
-bool llama_kv_cache_unified::init(
+llama_kv_cache_unified::llama_kv_cache_unified(
         const llama_model & model,
-      const llama_cparams & cparams,
                 ggml_type   type_k,
                 ggml_type   type_v,
+                     bool   v_trans,
+                     bool   offload,
                  uint32_t   kv_size,
-                     bool   offload) {
+                 uint32_t   padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
+    can_shift = true;
 
-    recurrent = llama_model_is_recurrent(&model);
-    v_trans   = !recurrent && !cparams.flash_attn;
-    can_shift = !recurrent;
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n",
-            __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift);
+    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
 
     head = 0;
     size = kv_size;
@@ -76,23 +84,20 @@ bool llama_kv_cache_unified::init(
 
         const char * dev_name = "CPU";
 
-        ggml_backend_buffer_type_t buft;
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
         if (offload) {
             auto * dev = model.dev_layer(i);
             buft = ggml_backend_dev_buffer_type(dev);
 
             dev_name = ggml_backend_dev_name(dev);
-        } else {
-            buft = ggml_backend_cpu_buffer_type();
         }
 
-        LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__,
-                i, n_embd_k_gqa, n_embd_v_gqa, dev_name);
+        LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name);
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
@@ -110,55 +115,28 @@ bool llama_kv_cache_unified::init(
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
+            throw std::runtime_error("failed to allocate buffer for kv cache");
         }
         ggml_backend_buffer_clear(buf, 0);
         LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
 
-    return true;
-}
-
-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
-}
-
-size_t llama_kv_cache_unified::total_size() const {
-    size_t size = 0;
-    for (const auto & buf : bufs) {
-        size += ggml_backend_buffer_get_size(buf.get());
-    }
-
-    return size;
-}
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
 
-llama_pos llama_kv_cache_unified::pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
+        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
     }
-
-    return pos_max;
 }
 
 void llama_kv_cache_unified::clear() {
     for (int32_t i = 0; i < (int32_t) size; ++i) {
         cells[i].pos = -1;
         cells[i].seq_id.clear();
-        cells[i].src = -1;
-        cells[i].tail = -1;
     }
     head = 0;
     used = 0;
@@ -179,35 +157,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // models like Mamba or RWKV can't have a state partially erased
-    if (recurrent) {
-        if (seq_id >= (int64_t) size) {
-            // could be fatal
-            return false;
-        }
-        if (0 <= seq_id) {
-            int32_t & tail_id = cells[seq_id].tail;
-            if (tail_id >= 0) {
-                const llama_kv_cell & cell = cells[tail_id];
-                // partial intersection is invalid
-                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                    return false;
-                }
-                // invalidate tails which will be cleared
-                if (p0 <= cell.pos && cell.pos < p1) {
-                    tail_id = -1;
-                }
-            }
-        } else {
-            // seq_id is negative, then the range should include everything or nothing
-            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-                return false;
-            }
-        }
-
-        return true;
-    }
-
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].pos >= p0 && cells[i].pos < p1) {
             if (seq_id < 0) {
@@ -224,7 +173,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 }
 
                 cells[i].pos = -1;
-                cells[i].src = -1;
 
                 if (new_head == size) {
                     new_head = i;
@@ -254,34 +202,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    if (recurrent) {
-        if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-            llama_kv_cell & tail_src = cells[seq_id_src];
-            llama_kv_cell & tail_dst = cells[seq_id_dst];
-            if (tail_dst.tail >= 0) {
-                // clear destination seq_id if it wasn't empty
-                llama_kv_cell & cell_dst = cells[tail_dst.tail];
-
-                cell_dst.seq_id.erase(seq_id_dst);
-                tail_dst.tail = -1;
-                if (cell_dst.seq_id.empty()) {
-                    cell_dst.pos = -1;
-                    cell_dst.delta = -1;
-                    cell_dst.src = -1;
-                    used -= 1;
-                }
-            }
-            if (tail_src.tail >= 0) {
-                llama_kv_cell & cell_src = cells[tail_src.tail];
-
-                cell_src.seq_id.insert(seq_id_dst);
-                tail_dst.tail = tail_src.tail;
-            }
-        }
-
-        return;
-    }
-
     // otherwise, this is the KV of a Transformer-like model
     head = 0;
 
@@ -296,17 +216,12 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
     uint32_t new_head = size;
 
     for (uint32_t i = 0; i < size; ++i) {
-        if (recurrent && (llama_seq_id) i != seq_id) {
-            cells[i].tail = -1;
-        }
-
         if (!cells[i].has_seq_id(seq_id)) {
             if (cells[i].pos >= 0) {
                 used--;
             }
 
             cells[i].pos = -1;
-            cells[i].src = -1;
             cells[i].seq_id.clear();
 
             if (new_head == size){
@@ -344,20 +259,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    if (recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be shifted
-        if (0 <= seq_id && seq_id < (int64_t) size) {
-            const int32_t tail_id = cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos += delta;
-                }
-            }
-        }
-        return;
-    }
-
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
             has_shift = true;
@@ -400,21 +301,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    if (recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be changed
-        if (0 <= seq_id && seq_id < (int64_t) size) {
-            const int32_t tail_id = cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos /= d;
-                }
-            }
-        }
-
-        return;
-    }
-
     for (uint32_t i = 0; i < size; ++i) {
         if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
             has_shift = true;
@@ -440,23 +326,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-void llama_kv_cache_unified::defrag() {
-    if (!recurrent) {
-        do_defrag = true;
-    }
-}
-
 void llama_kv_cache_unified::restore() {
     if (pending.ranges.empty()) {
         return;
     }
 
-    // TODO: tmp - move to llama_kv_cache_recurrent
-    if (recurrent) {
-        seq_rm(-1, -1, -1);
-        return;
-    }
-
     uint32_t new_head = size;
 
     for (auto & range : pending.ranges) {
@@ -469,7 +343,6 @@ void llama_kv_cache_unified::restore() {
             }
 
             cells[i].pos = -1;
-            cells[i].src = -1;
         }
 
         new_head = std::min(new_head, range.c0);
@@ -481,11 +354,6 @@ void llama_kv_cache_unified::restore() {
 }
 
 void llama_kv_cache_unified::commit() {
-    // TODO: tmp - move to llama_kv_cache_recurrent
-    if (recurrent) {
-        return;
-    }
-
     if (pending.ranges.empty()) {
         LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
                 __func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@@ -495,183 +363,110 @@ void llama_kv_cache_unified::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_unified::get_can_shift() const {
-    return can_shift;
-}
+bool llama_kv_cache_unified::update(llama_context & lctx) {
+    bool need_reserve = false;
 
-bool llama_kv_cache_unified::find_slot(
-       const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
-    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+    auto * sched = lctx.get_sched();
 
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*ubatch.n_tokens) {
-        head = 0;
-    }
+    if (has_shift) {
+        if (!get_can_shift()) {
+            GGML_ABORT("The current KV cache / model configuration does not support K-shift");
+        }
 
-    if (recurrent) {
-        // For recurrent state architectures (like Mamba or RWKV),
-        // each cache cell can store the state for a whole sequence.
-        // A slot should be always be contiguous.
+        LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
 
-        // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(ubatch.equal_seqs);
+        // apply K-shift if needed
+        if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(sched);
 
-        int32_t min = size - 1;
-        int32_t max = 0;
+            auto * gf = lctx.graph_init();
 
-        // everything should fit if all seq_ids are smaller than the max
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = ubatch.n_seq_id[s];
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = ubatch.seq_id[s][j];
+            auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
 
-                if (seq_id < 0 || (uint32_t) seq_id >= size) {
-                    // too big seq_id
-                    // TODO: would it be possible to resize the cache instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
-                    return false;
-                }
-                if (j > 0) {
-                    llama_kv_cell & seq = cells[seq_id];
-                    if (seq.tail >= 0) {
-                        llama_kv_cell & cell = cells[seq.tail];
-                        // clear cells from seq_ids that become shared
-                        // (should not normally happen, but let's handle it anyway)
-                        cell.seq_id.erase(seq_id);
-                        seq.tail = -1;
-                        if (cell.seq_id.empty()) {
-                            cell.pos = -1;
-                            cell.src = -1;
-                            used -= 1;
-                        }
-                    }
-                }
-            }
+            ggml_backend_sched_alloc_graph(sched, gf);
+
+            res->set_inputs(nullptr);
+
+            lctx.graph_compute(gf, false);
+
+            need_reserve = true;
         }
 
-#ifndef NDEBUG
         {
-            std::vector<int32_t> tails_verif;
-            tails_verif.assign(size, -1);
-            for (uint32_t i = 0; i < size; ++i) {
-                llama_kv_cell & cell = cells[i];
-                for (llama_seq_id seq_id : cell.seq_id) {
-                    if (tails_verif[seq_id] != -1) {
-                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
-                    }
-                    tails_verif[seq_id] = i;
-                }
-            }
+            has_shift = false;
+
             for (uint32_t i = 0; i < size; ++i) {
-                if (tails_verif[i] != cells[i].tail) {
-                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
-                }
+                cells[i].delta = 0;
             }
         }
-#endif
+    }
 
-        // find next empty cell
-        uint32_t next_empty_cell = head;
+    if (do_defrag) {
+        LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
-        for (uint32_t i = 0; i < size; ++i) {
-            if (next_empty_cell >= size) { next_empty_cell -= size; }
-            llama_kv_cell & cell = cells[next_empty_cell];
-            if (cell.is_empty()) { break; }
-            next_empty_cell += 1;
-        }
+        if (defrag_prepare(lctx.graph_max_nodes())) {
+            ggml_backend_sched_reset(sched);
 
-        // find usable cell range
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][0];
-            llama_kv_cell & seq_meta = cells[seq_id];
-            bool has_cell = false;
-            if (seq_meta.tail >= 0) {
-                llama_kv_cell & cell = cells[seq_meta.tail];
-                GGML_ASSERT(cell.has_seq_id(seq_id));
-                // does this seq_id "own" the cell?
-                if (cell.seq_id.size() == 1) { has_cell = true; }
-            }
-            if (!has_cell) {
-                llama_kv_cell & empty_cell = cells[next_empty_cell];
-                GGML_ASSERT(empty_cell.is_empty());
-                // copy old tail into the empty cell
-                if (seq_meta.tail >= 0) {
-                    llama_kv_cell & orig_cell = cells[seq_meta.tail];
-                    empty_cell.pos = orig_cell.pos;
-                    empty_cell.src = orig_cell.src;
-                    orig_cell.seq_id.erase(seq_id);
-                    empty_cell.seq_id.insert(seq_id); // will be overwritten
-                }
-                seq_meta.tail = next_empty_cell;
-                // find next empty cell
-                if (s + 1 < n_seqs) {
-                    next_empty_cell += 1;
-                    for (uint32_t i = 0; i < size; ++i) {
-                        if (next_empty_cell >= size) { next_empty_cell -= size; }
-                        llama_kv_cell & cell = cells[next_empty_cell];
-                        if (cell.is_empty()) { break; }
-                        next_empty_cell += 1;
-                    }
-                }
-            }
-            if (min > seq_meta.tail) { min = seq_meta.tail; }
-            if (max < seq_meta.tail) { max = seq_meta.tail; }
-        }
+            auto * gf = lctx.graph_init();
 
-        // gather and re-order
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            int32_t dst_id = s + min;
-            int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
-            if (dst_id != src_id) {
-                llama_kv_cell & dst_cell = cells[dst_id];
-                llama_kv_cell & src_cell = cells[src_id];
+            auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
 
-                std::swap(dst_cell.pos, src_cell.pos);
-                std::swap(dst_cell.src, src_cell.src);
-                std::swap(dst_cell.seq_id, src_cell.seq_id);
+            ggml_backend_sched_alloc_graph(sched, gf);
 
-                // swap tails (assuming they NEVER overlap)
-                for (const llama_seq_id seq_id : src_cell.seq_id) {
-                    cells[seq_id].tail = src_id;
-                }
-                for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                    cells[seq_id].tail = dst_id;
-                }
-            }
-        }
+            res->set_inputs(nullptr);
 
-        // update the pos of the used seqs
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-            int32_t cell_id = s + min;
-            llama_kv_cell & cell = cells[cell_id];
+            lctx.graph_compute(gf, false);
 
-            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
-                // What should happen when the pos backtracks or skips a value?
-                // Clearing the state mid-batch would require special-casing which isn't done.
-                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
-            }
-            cell.pos = last_pos;
-            cell.seq_id.clear();
-            for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = ubatch.seq_id[s][j];
-                cell.seq_id.insert(seq_id);
-                cells[seq_id].tail = cell_id;
-            }
+            need_reserve = true;
         }
 
-        // allow getting the range of used cells, from head to head + n
-        head = min;
-        n    = max - min + 1;
-        used = std::count_if(cells.begin(), cells.end(),
-            [](const llama_kv_cell& cell){ return !cell.is_empty(); });
+        do_defrag = false;
+    }
+
+    return need_reserve;
+}
+
+void llama_kv_cache_unified::defrag_sched(float thold) {
+    // - do not defrag small contexts (i.e. < 2048 tokens)
+    // - count the padding towards the number of used tokens
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+
+    // queue defragmentation for next llama_kv_cache_update
+    if (fragmentation > thold) {
+        LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+        do_defrag = true;
+    }
+}
+
+void llama_kv_cache_unified::set_full() {
+    n = size;
+}
+
+llama_sbatch llama_kv_cache_unified::sbatch_init(
+        const llama_batch & batch,
+        bool logits_all) {
+    return llama_sbatch(batch, hparams.n_embd, true, logits_all);
+}
+
+llama_ubatch llama_kv_cache_unified::ubatch_next(
+        llama_sbatch & sbatch,
+        uint32_t n_ubatch,
+        bool embd_pooled) const {
+    GGML_UNUSED(embd_pooled);
+    return sbatch.split_simple(n_ubatch);
+}
+
+bool llama_kv_cache_unified::find_slot(
+       const llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
-        // sanity check
-        return n >= n_seqs;
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*ubatch.n_tokens) {
+        head = 0;
     }
 
     // otherwise, one cell per token.
@@ -725,24 +520,50 @@ bool llama_kv_cache_unified::find_slot(
 
     pending.ranges.push_back({head, head + n_tokens});
 
+    // a heuristic, to avoid attending the full cache if it is not yet utilized
+    // after enough generations, the benefit from this heuristic disappears
+    // if we start defragmenting the cache, the benefit from this will be more important
+    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+
+    //printf("n = %5d, used = %5d, head = %5d\n", n, used, head);
+
     return true;
 }
 
-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
+int32_t llama_kv_cache_unified::get_n_tokens() const {
+    int32_t result = 0;
+
+    for (uint32_t i = 0; i < size; i++) {
+        result += cells[i].seq_id.size();
+    }
+
+    return result;
 }
 
-uint32_t llama_kv_cache_unified::cell_max() const {
-    for (uint32_t i = size; i > 0; --i) {
-        const llama_kv_cell & cell = cells[i - 1];
+int32_t llama_kv_cache_unified::get_used_cells() const {
+    return used;
+}
 
-        if (cell.pos >= 0 && !cell.is_empty()) {
-            return i;
-        }
+bool llama_kv_cache_unified::get_can_shift() const {
+    return can_shift;
+}
+
+llama_pos llama_kv_cache_unified::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cell : cells) {
+        pos_max = std::max(pos_max, cell.pos);
     }
 
-    return 0;
+    return pos_max;
+}
+
+size_t llama_kv_cache_unified::total_size() const {
+    size_t size = 0;
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
 }
 
 size_t llama_kv_cache_unified::size_k_bytes() const {
@@ -765,68 +586,331 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
     return size_v_bytes;
 }
 
-bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
-    const uint32_t n_layer = hparams.n_layer;
+ggml_tensor * llama_kv_cache_unified::build_rope_shift(
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_tensor * cur,
+                ggml_tensor * shift,
+                ggml_tensor * factors,
+                      float   freq_base,
+                      float   freq_scale) const {
+    const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
 
-    const uint32_t n_kv   = cell_max();
-    const uint32_t n_used = used;
+    const auto & yarn_ext_factor = cparams.yarn_ext_factor;
+    const auto & yarn_beta_fast  = cparams.yarn_beta_fast;
+    const auto & yarn_beta_slow  = cparams.yarn_beta_slow;
 
-    assert(n_used <= n_kv);
+    const auto & n_rot     = hparams.n_rot;
+    const auto & rope_type = hparams.rope_type;
 
-    //const int64_t t_start = ggml_time_us();
+    // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
+    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
 
-    // number of cells moved
-    uint32_t n_moves = 0;
+    ggml_tensor * tmp;
 
-    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
-    //   - source view, destination view, copy operation
-    //   - x2 for keys and values
-    //const uint32_t max_moves = max_nodes()/(6*n_layer);
-    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+    if (ggml_is_quantized(cur->type)) {
+        // dequantize to f32 -> RoPE -> quantize back
+        tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
 
-    // determine which KV cells to move where
-    //
-    //  cell i moves to ids[i]
-    //
-    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
-    //
-    auto & ids = defrag_info.ids;
+        tmp = ggml_rope_ext(ctx, tmp,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
-    ids.clear();
-    ids.resize(n_kv, n_kv);
+        tmp = ggml_cpy(ctx, tmp, cur);
+    } else {
+        // we rotate only the first n_rot dimensions
+        tmp = ggml_rope_ext_inplace(ctx, cur,
+                shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
+    }
 
-    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = cells[i0];
+    return tmp;
+}
 
-        if (!cell0.is_empty()) {
-            ids[i0] = i0;
+class llm_graph_input_k_shift : public llm_graph_input_i {
+public:
+    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    virtual ~llm_graph_input_k_shift() = default;
 
-            continue;
-        }
+    void set_input(const llama_ubatch * ubatch) override;
 
-        // found a hole - fill it with data from the end of the cache
+    ggml_tensor * k_shift; // I32 [kv_size]
 
-        uint32_t nh = 1;
+    const llama_kv_cache_unified * kv_self;
+};
 
-        // determine the size of the hole
-        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
-            nh++;
+void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+
+    if (k_shift) {
+        assert(ggml_backend_buffer_is_host(k_shift->buffer));
+
+        int32_t * data = (int32_t *) k_shift->data;
+
+        for (uint32_t i = 0; i < kv_self->size; ++i) {
+            data[i] = kv_self->cells[i].delta;
         }
+    }
+}
 
-        uint32_t nf = 0;
-        uint32_t is = n_kv - 1;
+llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf) const {
+    auto res = std::make_unique<llm_graph_result>();
 
-        // starting from the end, find nh non-empty cells
-        for (; is > i0; --is) {
-            const auto & cell1 = cells[is];
+    const auto & n_layer = hparams.n_layer;
 
-            if (cell1.is_empty() || ids[is] != n_kv) {
-                continue;
-            }
+    const auto & n_embd_head_k = hparams.n_embd_head_k;
+  //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-            // non-empty cell which is not yet moved
-            nf++;
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+
+    //GGML_ASSERT(kv_self->size == n_ctx);
+
+    auto inp = std::make_unique<llm_graph_input_k_shift>(this);
+
+    inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx);
+    ggml_set_input(inp->k_shift);
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const int64_t n_head_kv    = hparams.n_head_kv(il);
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
+        const bool is_swa = hparams.is_swa(il);
+
+        // note: the swa rope params could become part of the cparams in the future
+        //       if we decide to make them configurable, like the non-sliding ones
+        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+
+        ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+
+        ggml_tensor * k =
+            ggml_view_3d(ctx, k_l[il],
+                n_embd_head_k, n_head_kv, size,
+                ggml_row_size(k_l[il]->type, n_embd_head_k),
+                ggml_row_size(k_l[il]->type, n_embd_k_gqa),
+                0);
+
+        ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    res->add_input(std::move(inp));
+
+    return res;
+}
+
+llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
+        const llama_cparams & cparams,
+               ggml_context * ctx,
+                ggml_cgraph * gf) const {
+    auto res = std::make_unique<llm_graph_result>();
+
+    const auto & ids = defrag_info.ids;
+
+#if 0
+    // CPU defrag
+    //
+    // TODO: optimizations are possible:
+    //       - multiple threads
+    //       - avoid copying to the host memory when already there
+    //
+    // likely not worth the effort, as we have ggml_graph based defrag
+    //
+
+    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+    const uint32_t kv_size = size;
+
+    std::vector<uint8_t> buf_k;
+    std::vector<uint8_t> buf_v;
+
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        const size_t k_size     = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
+
+        const size_t v_size_el = ggml_type_size(v_l[il]->type);
+        const size_t v_size    = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
+
+        buf_k.resize(k_size);
+        buf_v.resize(v_size);
+
+        ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
+
+        // batch move [i, i+nm) to [id, id+nm)
+        // note: cells can move only to a lower index
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            const uint32_t id = ids[i];
+
+            if (i == id || id == n_kv) {
+                continue;
+            }
+
+            uint32_t nm = 1;
+
+            while (i + nm < n_kv && ids[i + nm] == id + nm) {
+                nm++;
+            }
+
+            // move keys
+            {
+                const int64_t os =  i*k_size_row;
+                const int64_t od = id*k_size_row;
+
+                memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
+            }
+
+            // move values (note: they are transposed)
+            {
+                const int64_t os =  i;
+                const int64_t od = id;
+
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
+                }
+            }
+
+            i += nm - 1;
+        }
+
+        ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
+        ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
+    }
+#else
+    for (uint32_t i = 0; i < ids.size(); ++i) {
+        const uint32_t id = ids[i];
+
+        if (i == id || id == ids.size()) {
+            continue;
+        }
+
+        uint32_t nm = 1;
+
+        while (i + nm < ids.size() && ids[i + nm] == id + nm) {
+            nm++;
+        }
+
+        for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+            const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+            ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*i));
+
+            ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il],
+                    n_embd_k_gqa, nm,
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(k_l[il]->type, n_embd_k_gqa*id));
+
+            ggml_tensor * view_v_src;
+            ggml_tensor * view_v_dst;
+
+            if (cparams.flash_attn) {
+                // NOTE: the V cache is not transposed when using flash attention
+                view_v_src = ggml_view_2d(ctx, v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*i));
+
+                view_v_dst = ggml_view_2d(ctx, v_l[il],
+                        n_embd_v_gqa, nm,
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa),
+                        ggml_row_size(v_l[il]->type, n_embd_v_gqa*id));
+            } else {
+                view_v_src = ggml_view_2d(ctx, v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(v_l[il]->type, size),
+                        ggml_row_size(v_l[il]->type, i));
+
+                view_v_dst = ggml_view_2d(ctx, v_l[il],
+                        nm, n_embd_v_gqa,
+                        ggml_row_size(v_l[il]->type, size),
+                        ggml_row_size(v_l[il]->type, id));
+            }
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst));
+        }
+
+        i += nm - 1;
+    }
+
+    //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
+#endif
+
+    return res;
+}
+
+bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
+    const uint32_t n_layer = hparams.n_layer;
+
+    const uint32_t n_kv   = cell_max();
+    const uint32_t n_used = used;
+
+    assert(n_used <= n_kv);
+
+    //const int64_t t_start = ggml_time_us();
+
+    // number of cells moved
+    uint32_t n_moves = 0;
+
+    // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag)
+    //   - source view, destination view, copy operation
+    //   - x2 for keys and values
+    //const uint32_t max_moves = max_nodes()/(6*n_layer);
+    // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
+    const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
+
+    // determine which KV cells to move where
+    //
+    //  cell i moves to ids[i]
+    //
+    //  if ids[i] == i || ids[i] == n_kv, then cell i is not moved
+    //
+    auto & ids = defrag_info.ids;
+
+    ids.clear();
+    ids.resize(n_kv, n_kv);
+
+    for (uint32_t i0 = 0; i0 < n_used; ++i0) {
+        const auto & cell0 = cells[i0];
+
+        if (!cell0.is_empty()) {
+            ids[i0] = i0;
+
+            continue;
+        }
+
+        // found a hole - fill it with data from the end of the cache
+
+        uint32_t nh = 1;
+
+        // determine the size of the hole
+        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
+            nh++;
+        }
+
+        uint32_t nf = 0;
+        uint32_t is = n_kv - 1;
+
+        // starting from the end, find nh non-empty cells
+        for (; is > i0; --is) {
+            const auto & cell1 = cells[is];
+
+            if (cell1.is_empty() || ids[is] != n_kv) {
+                continue;
+            }
+
+            // non-empty cell which is not yet moved
+            nf++;
 
             if (nf == nh) {
                 break;
@@ -867,7 +951,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             cells[i0 + nf] = cell1;
 
             // clear the old cell and move the head there
-            cell1 = llama_kv_cell();
+            cell1 = kv_cell();
             head = n_used;
 
             if (!cont) {
@@ -895,13 +979,25 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         return false;
     }
 
-    LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves);
+    LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
 
-    LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer);
+    LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
 
     return true;
 }
 
+uint32_t llama_kv_cache_unified::cell_max() const {
+    for (uint32_t i = size; i > 0; --i) {
+        const kv_cell & cell = cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
+        }
+    }
+
+    return 0;
+}
+
 void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;
@@ -1110,7 +1206,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         clear();
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_kv_cell & cell = cells[i];
+            kv_cell & cell = cells[i];
 
             llama_pos pos;
             uint32_t  n_seq_id;
@@ -1133,15 +1229,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 }
 
                 cell.seq_id.insert(seq_id);
-
-                if (recurrent) {
-                    int32_t & tail = cells[seq_id].tail;
-                    if (tail != -1) {
-                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                        return false;
-                    }
-                    tail = i;
-                }
             }
         }
 
@@ -1149,14 +1236,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         used = cell_count;
     }
 
-    if (recurrent) {
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            uint32_t cell_id = head + i;
-            // make sure the recurrent states will keep their restored state
-            cells[cell_id].src = cell_id;
-        }
-    }
-
     return true;
 }
 
@@ -1174,7 +1253,1034 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
         return false;
     }
-    if (v_trans != (bool) v_trans) {
+    if (this->v_trans != (bool) v_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+        return false;
+    }
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+        // Read type of key
+        int32_t k_type_i_ref;
+        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+        const int32_t k_type_i = (int32_t) k_l[il]->type;
+        if (k_type_i != k_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+            return false;
+        }
+
+        // Read row size of key
+        uint64_t k_size_row_ref;
+        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        if (k_size_row != k_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+        }
+    }
+
+    if (!this->v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t v_size_row_ref;
+            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            if (v_size_row != v_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Read type of value
+            int32_t v_type_i_ref;
+            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            if (v_type_i != v_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t v_size_el_ref;
+            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+            const size_t v_size_el = ggml_type_size(v_l[il]->type);
+            if (v_size_el != v_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                return false;
+            }
+
+            // Read GQA embedding size
+            uint32_t n_embd_v_gqa_ref;
+            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (head + j * size) * v_size_el;
+                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+//
+// llama_kv_cache_recurrent
+//
+
+llama_kv_cache_recurrent::llama_kv_cache_recurrent(
+        const llama_model & model,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   offload,
+                 uint32_t   kv_size) : hparams(model.hparams) {
+    const int32_t n_layer = hparams.n_layer;
+
+    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+
+    head = 0;
+    size = kv_size;
+    used = 0;
+
+    this->type_k = type_k;
+    this->type_v = type_v;
+
+    cells.clear();
+    cells.resize(kv_size);
+
+    // create a context for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+
+            ctx_map[buft] = ctx;
+            ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    k_l.reserve(n_layer);
+    v_l.reserve(n_layer);
+
+    for (int i = 0; i < n_layer; i++) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) {
+            throw std::runtime_error("failed to create ggml context for kv cache");
+        }
+
+        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_format_name(k, "cache_k_l%d", i);
+        ggml_format_name(v, "cache_v_l%d", i);
+        k_l.push_back(k);
+        v_l.push_back(v);
+    }
+
+    // allocate tensors and initialize the buffers to avoid NaNs in the padding
+    for (auto it : ctx_map) {
+        auto * buft = it.first;
+        auto * ctx  = it.second;
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            throw std::runtime_error("failed to allocate buffer for kv cache");
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        bufs.emplace_back(buf);
+    }
+
+    {
+        const size_t memory_size_k = size_k_bytes();
+        const size_t memory_size_v = size_v_bytes();
+
+        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+    }
+}
+
+void llama_kv_cache_recurrent::clear() {
+    for (int32_t i = 0; i < (int32_t) size; ++i) {
+        cells[i].pos = -1;
+        cells[i].seq_id.clear();
+        cells[i].src = -1;
+        cells[i].tail = -1;
+    }
+    head = 0;
+    used = 0;
+
+    for (auto & buf : bufs) {
+        ggml_backend_buffer_clear(buf.get(), 0);
+    }
+}
+
+bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    uint32_t new_head = size;
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // models like Mamba or RWKV can't have a state partially erased
+    if (seq_id >= (int64_t) size) {
+        // could be fatal
+        return false;
+    }
+    if (0 <= seq_id) {
+        int32_t & tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const kv_cell & cell = cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+            // invalidate tails which will be cleared
+            if (p0 <= cell.pos && cell.pos < p1) {
+                tail_id = -1;
+            }
+        }
+    } else {
+        // seq_id is negative, then the range should include everything or nothing
+        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            return false;
+        }
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].pos >= p0 && cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+                cells[i].pos = -1;
+                cells[i].src = -1;
+                if (new_head == size) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+
+    return true;
+}
+
+void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
+        kv_cell & tail_src = cells[seq_id_src];
+        kv_cell & tail_dst = cells[seq_id_dst];
+        if (tail_dst.tail >= 0) {
+            // clear destination seq_id if it wasn't empty
+            kv_cell & cell_dst = cells[tail_dst.tail];
+
+            cell_dst.seq_id.erase(seq_id_dst);
+            tail_dst.tail = -1;
+            if (cell_dst.seq_id.empty()) {
+                cell_dst.pos = -1;
+                cell_dst.src = -1;
+                used -= 1;
+            }
+        }
+        if (tail_src.tail >= 0) {
+            kv_cell & cell_src = cells[tail_src.tail];
+
+            cell_src.seq_id.insert(seq_id_dst);
+            tail_dst.tail = tail_src.tail;
+        }
+    }
+}
+
+void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
+    uint32_t new_head = size;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if ((llama_seq_id) i != seq_id) {
+            cells[i].tail = -1;
+        }
+
+        if (!cells[i].has_seq_id(seq_id)) {
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+            cells[i].seq_id.clear();
+
+            if (new_head == size){
+                new_head = i;
+            }
+        } else {
+            cells[i].seq_id.clear();
+            cells[i].seq_id.insert(seq_id);
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    if (delta == 0) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the
+    if (p0 == p1) {
+        return;
+    }
+
+    // for Mamba-like or RWKV models, only the pos needs to be shifted
+    if (0 <= seq_id && seq_id < (int64_t) size) {
+        const int32_t tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            kv_cell & cell = cells[tail_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+    }
+}
+
+void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    // for Mamba-like or RWKV models, only the pos needs to be changed
+    if (0 <= seq_id && seq_id < (int64_t) size) {
+        const int32_t tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            kv_cell & cell = cells[tail_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+    }
+}
+
+llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos result = 0;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::max(result, cells[i].pos);
+        }
+    }
+
+    return result;
+}
+
+void llama_kv_cache_recurrent::restore() {
+    if (pending.ranges.empty()) {
+        return;
+    }
+
+    seq_rm(-1, -1, -1);
+}
+
+void llama_kv_cache_recurrent::commit() {
+    pending.ranges.clear();
+}
+
+bool llama_kv_cache_recurrent::update(llama_context & lctx) {
+    GGML_UNUSED(lctx);
+    return false;
+}
+
+void llama_kv_cache_recurrent::defrag_sched(float thold) {
+    GGML_UNUSED(thold);
+    // noop
+}
+
+void llama_kv_cache_recurrent::set_full() {
+    n = size;
+}
+
+llama_sbatch llama_kv_cache_recurrent::sbatch_init(
+        const llama_batch & batch,
+        bool logits_all) {
+    return llama_sbatch(batch, hparams.n_embd, false, logits_all);
+}
+
+llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
+    if (embd_pooled) {
+        // Pooled embeddings cannot be split across ubatches (yet)
+        return sbatch.split_seq(n_ubatch);
+    }
+
+    return sbatch.split_equal(n_ubatch);
+}
+
+bool llama_kv_cache_recurrent::find_slot(
+       const llama_ubatch & ubatch) {
+    const uint32_t n_tokens = ubatch.n_tokens;
+    const uint32_t n_seqs   = ubatch.n_seqs;
+
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*n_tokens) {
+        head = 0;
+    }
+
+    // For recurrent state architectures (like Mamba or RWKV),
+    // each cache cell can store the state for a whole sequence.
+    // A slot should be always be contiguous.
+
+    // can only process batches with an equal number of new tokens in each sequence
+    GGML_ASSERT(ubatch.equal_seqs);
+
+    int32_t min = size - 1;
+    int32_t max = 0;
+
+    // everything should fit if all seq_ids are smaller than the max
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const uint32_t n_seq_id = ubatch.n_seq_id[s];
+        for (uint32_t j = 0; j < n_seq_id; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+
+            if (seq_id < 0 || (uint32_t) seq_id >= size) {
+                // too big seq_id
+                // TODO: would it be possible to resize the cache instead?
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+                return false;
+            }
+            if (j > 0) {
+                kv_cell & seq = cells[seq_id];
+                if (seq.tail >= 0) {
+                    kv_cell & cell = cells[seq.tail];
+                    // clear cells from seq_ids that become shared
+                    // (should not normally happen, but let's handle it anyway)
+                    cell.seq_id.erase(seq_id);
+                    seq.tail = -1;
+                    if (cell.seq_id.empty()) {
+                        cell.pos = -1;
+                        cell.src = -1;
+                        used -= 1;
+                    }
+                }
+            }
+        }
+    }
+
+#ifndef NDEBUG
+    {
+        std::vector<int32_t> tails_verif;
+        tails_verif.assign(size, -1);
+        for (uint32_t i = 0; i < size; ++i) {
+            kv_cell & cell = cells[i];
+            for (llama_seq_id seq_id : cell.seq_id) {
+                if (tails_verif[seq_id] != -1) {
+                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
+                }
+                tails_verif[seq_id] = i;
+            }
+        }
+        for (uint32_t i = 0; i < size; ++i) {
+            if (tails_verif[i] != cells[i].tail) {
+                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
+            }
+        }
+    }
+#endif
+
+    // find next empty cell
+    uint32_t next_empty_cell = head;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (next_empty_cell >= size) { next_empty_cell -= size; }
+        kv_cell & cell = cells[next_empty_cell];
+        if (cell.is_empty()) { break; }
+        next_empty_cell += 1;
+    }
+
+    // find usable cell range
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        kv_cell & seq_meta = cells[seq_id];
+        bool has_cell = false;
+        if (seq_meta.tail >= 0) {
+            kv_cell & cell = cells[seq_meta.tail];
+            GGML_ASSERT(cell.has_seq_id(seq_id));
+            // does this seq_id "own" the cell?
+            if (cell.seq_id.size() == 1) { has_cell = true; }
+        }
+        if (!has_cell) {
+            kv_cell & empty_cell = cells[next_empty_cell];
+            GGML_ASSERT(empty_cell.is_empty());
+            // copy old tail into the empty cell
+            if (seq_meta.tail >= 0) {
+                kv_cell & orig_cell = cells[seq_meta.tail];
+                empty_cell.pos = orig_cell.pos;
+                empty_cell.src = orig_cell.src;
+                orig_cell.seq_id.erase(seq_id);
+                empty_cell.seq_id.insert(seq_id); // will be overwritten
+            }
+            seq_meta.tail = next_empty_cell;
+            // find next empty cell
+            if (s + 1 < n_seqs) {
+                next_empty_cell += 1;
+                for (uint32_t i = 0; i < size; ++i) {
+                    if (next_empty_cell >= size) { next_empty_cell -= size; }
+                    kv_cell & cell = cells[next_empty_cell];
+                    if (cell.is_empty()) { break; }
+                    next_empty_cell += 1;
+                }
+            }
+        }
+        if (min > seq_meta.tail) { min = seq_meta.tail; }
+        if (max < seq_meta.tail) { max = seq_meta.tail; }
+    }
+
+    // gather and re-order
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        int32_t dst_id = s + min;
+        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        if (dst_id != src_id) {
+            kv_cell & dst_cell = cells[dst_id];
+            kv_cell & src_cell = cells[src_id];
+
+            std::swap(dst_cell.pos, src_cell.pos);
+            std::swap(dst_cell.src, src_cell.src);
+            std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+            // swap tails (assuming they NEVER overlap)
+            for (const llama_seq_id seq_id : src_cell.seq_id) {
+                cells[seq_id].tail = src_id;
+            }
+            for (const llama_seq_id seq_id : dst_cell.seq_id) {
+                cells[seq_id].tail = dst_id;
+            }
+        }
+    }
+
+    // update the pos of the used seqs
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+        int32_t cell_id = s + min;
+        kv_cell & cell = cells[cell_id];
+
+        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+            // What should happen when the pos backtracks or skips a value?
+            // Clearing the state mid-batch would require special-casing which isn't done.
+            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
+        }
+        cell.pos = last_pos;
+        cell.seq_id.clear();
+        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+            cell.seq_id.insert(seq_id);
+            cells[seq_id].tail = cell_id;
+        }
+    }
+
+    // allow getting the range of used cells, from head to head + n
+    head = min;
+    n    = max - min + 1;
+    used = std::count_if(cells.begin(), cells.end(),
+        [](const kv_cell & cell){ return !cell.is_empty(); });
+
+    // sanity check
+    return n >= n_seqs;
+}
+
+int32_t llama_kv_cache_recurrent::get_n_tokens() const {
+    int32_t result = 0;
+
+    for (uint32_t i = 0; i < size; i++) {
+        result += cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int32_t llama_kv_cache_recurrent::get_used_cells() const {
+    return used;
+}
+
+llama_pos llama_kv_cache_recurrent::get_pos_max() const {
+    llama_pos pos_max = -1;
+    for (const auto & cell : cells) {
+        pos_max = std::max(pos_max, cell.pos);
+    }
+
+    return pos_max;
+}
+
+bool llama_kv_cache_recurrent::get_can_shift() const {
+    return false;
+}
+
+int32_t llama_kv_cache_recurrent::s_copy(int i) const {
+    const uint32_t cell_id = i + head;
+
+    //////////////////////////////////////////////
+    // TODO: this should not mutate the KV cache !
+    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
+
+    // prevent out-of-bound sources
+    if (cell.src < 0 || (uint32_t) cell.src >= size) {
+        cell.src = cell_id;
+    }
+
+    int32_t res = cell.src;
+
+    // TODO: do not mutate the KV cache
+    // ensure copy only happens once
+    if (cell.src != (int32_t) cell_id) {
+        cell.src = cell_id;
+    }
+
+    return res;
+}
+
+float llama_kv_cache_recurrent::s_mask(int i) const {
+    const uint32_t cell_id = i + head;
+
+    //////////////////////////////////////////////
+    // TODO: this should not mutate the KV cache !
+    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
+
+    float res = (float) (cell.src >= 0);
+
+    // only clear once
+    if (cell.src < 0) {
+        cell.src = cell_id;
+    }
+
+    return res;
+}
+
+uint32_t llama_kv_cache_recurrent::cell_max() const {
+    for (uint32_t i = size; i > 0; --i) {
+        const kv_cell & cell = cells[i - 1];
+
+        if (cell.pos >= 0 && !cell.is_empty()) {
+            return i;
+        }
+    }
+
+    return 0;
+}
+
+size_t llama_kv_cache_recurrent::total_size() const {
+    size_t size = 0;
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+size_t llama_kv_cache_recurrent::size_k_bytes() const {
+    size_t size_k_bytes = 0;
+
+    for (const auto & k : k_l) {
+        size_k_bytes += ggml_nbytes(k);
+    }
+
+    return size_k_bytes;
+}
+
+size_t llama_kv_cache_recurrent::size_v_bytes() const {
+    size_t size_v_bytes = 0;
+
+    for (const auto & v : v_l) {
+        size_v_bytes += ggml_nbytes(v);
+    }
+
+    return size_v_bytes;
+}
+
+void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+    uint32_t cell_count = 0;
+
+    // Count the number of cells with the specified seq_id
+    // Find all the ranges of cells with this seq id (or all, when -1)
+    uint32_t cell_range_begin = size;
+    for (uint32_t i = 0; i < size; ++i) {
+        const auto & cell = cells[i];
+        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+            ++cell_count;
+            if (cell_range_begin == size) {
+                cell_range_begin = i;
+            }
+        } else {
+            if (cell_range_begin != size) {
+                cell_ranges.emplace_back(cell_range_begin, i);
+                cell_range_begin = size;
+            }
+        }
+    }
+    if (cell_range_begin != size) {
+        cell_ranges.emplace_back(cell_range_begin, size);
+    }
+
+    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+    uint32_t cell_count_check = 0;
+    for (const auto & range : cell_ranges) {
+        cell_count_check += range.second - range.first;
+    }
+    GGML_ASSERT(cell_count == cell_count_check);
+
+    io.write(&cell_count, sizeof(cell_count));
+
+    state_write_meta(io, cell_ranges, seq_id);
+    state_write_data(io, cell_ranges);
+}
+
+void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    uint32_t cell_count;
+    io.read_to(&cell_count, sizeof(cell_count));
+
+    bool res = true;
+    res = res && state_read_meta(io, cell_count, seq_id);
+    res = res && state_read_data(io, cell_count);
+
+    if (!res) {
+        if (seq_id == -1) {
+            clear();
+        } else {
+            seq_rm(seq_id, -1, -1);
+        }
+        throw std::runtime_error("failed to restore kv cache");
+    }
+}
+
+void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+    for (const auto & range : cell_ranges) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            const auto & cell = cells[i];
+            const llama_pos pos      = cell.pos;
+            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+
+            io.write(&pos,      sizeof(pos));
+            io.write(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id) {
+                for (auto seq_id : cell.seq_id) {
+                    io.write(&seq_id, sizeof(seq_id));
+                }
+            }
+        }
+    }
+}
+
+void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    const uint32_t v_trans = 0;
+    const uint32_t n_layer = hparams.n_layer;
+
+    io.write(&v_trans, sizeof(v_trans));
+    io.write(&n_layer, sizeof(n_layer));
+
+    std::vector<uint8_t> tmp_buf;
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+        // Write key type
+        const int32_t k_type_i = (int32_t)k_l[il]->type;
+        io.write(&k_type_i, sizeof(k_type_i));
+
+        // Write row size of key
+        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
+        io.write(&k_size_row, sizeof(k_size_row));
+
+        // Read each range of cells of k_size length each into tmp_buf and write out
+        for (const auto & range : cell_ranges) {
+            const size_t range_size = range.second - range.first;
+            const size_t buf_size = range_size * k_size_row;
+            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
+        }
+    }
+
+    if (!v_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Write value type
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write row size of value
+            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
+            io.write(&v_size_row, sizeof(v_size_row));
+
+            // Read each range of cells of v_size length each into tmp_buf and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * v_size_row;
+                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
+            }
+        }
+    } else {
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t kv_size = size;
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+            // Write value type
+            const int32_t v_type_i = (int32_t)v_l[il]->type;
+            io.write(&v_type_i, sizeof(v_type_i));
+
+            // Write element size
+            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
+            io.write(&v_size_el, sizeof(v_size_el));
+
+            // Write GQA embedding size
+            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                // Read each range of cells of v_size_el length each into tmp_buf and write out
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                    const size_t buf_size = range_size * v_size_el;
+                    io.write_tensor(v_l[il], src_offset, buf_size);
+                }
+            }
+        }
+    }
+}
+
+bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    if (dest_seq_id != -1) {
+        // single sequence
+
+        seq_rm(dest_seq_id, -1, -1);
+
+        llama_sbatch sbatch;
+        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+
+        batch.n_tokens = cell_count;
+        batch.n_seq_tokens = cell_count;
+        batch.n_seqs = 1;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 0) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            batch.pos[i] = pos;
+        }
+        batch.n_seq_id[0] = 1;
+        batch.seq_id[0] = &dest_seq_id;
+        if (!find_slot(batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+        commit();
+
+        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // Assume that this is one contiguous block of cells
+        GGML_ASSERT(head + cell_count <= size);
+        GGML_ASSERT(cells[head].pos == batch.pos[0]);
+        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
+        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+    } else {
+        // whole KV cache restore
+
+        if (cell_count > size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        clear();
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            kv_cell & cell = cells[i];
+
+            llama_pos pos;
+            uint32_t  n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cell.pos = pos;
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+
+                // TODO: llama_kv_cache_recurrent should have a notion of max sequences
+                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                if (seq_id < 0) {
+                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+                    return false;
+                }
+
+                cell.seq_id.insert(seq_id);
+
+                int32_t & tail = cells[seq_id].tail;
+                if (tail != -1) {
+                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+                    return false;
+                }
+                tail = i;
+            }
+        }
+
+        head = 0;
+        used = cell_count;
+    }
+
+    for (uint32_t i = 0; i < cell_count; ++i) {
+        uint32_t cell_id = head + i;
+        // make sure the recurrent states will keep their restored state
+        cells[cell_id].src = cell_id;
+    }
+
+    return true;
+}
+
+bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    uint32_t v_trans;
+    uint32_t n_layer;
+    io.read_to(&v_trans, sizeof(v_trans));
+    io.read_to(&n_layer, sizeof(n_layer));
+
+    if (n_layer != hparams.n_layer) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+        return false;
+    }
+    if (cell_count > size) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+        return false;
+    }
+    if (false != (bool) v_trans) {
         LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
         return false;
     }
@@ -1326,7 +2432,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache
         view->cells_sequences = (llama_seq_id *)p;
     }
 
-    const std::vector<llama_kv_cell> & kv_cells = kvu->cells;
+    const std::vector<llama_kv_cache_unified::kv_cell> & kv_cells = kvu->cells;
     llama_kv_cache_view_cell * c_curr = view->cells;
     llama_seq_id * cs_curr = view->cells_sequences;
     int32_t used_cells = 0;
index 56c74035ae1b9d8d1a12c7ef77cec66938c61861..bf3b4b6a4430f00088d5bbdae82cccffb8778c3e 100644 (file)
@@ -2,32 +2,72 @@
 
 #include "llama.h"
 #include "llama-io.h"
+#include "llama-graph.h"
 #include "llama-memory.h"
 
 #include "ggml-cpp.h"
 
-#include <functional>
 #include <set>
 #include <vector>
 
 struct llama_cparams;
 struct llama_hparams;
 struct llama_ubatch;
+struct llama_sbatch;
+struct llama_model;
+struct llama_context;
 
 struct llama_kv_cache : public llama_memory_i {
-    using llama_memory_i::llama_memory_i;
+    virtual ~llama_kv_cache() = default;
 
-    virtual void restore() = 0; // call if batch processing fails - restores the cache state
-    virtual void commit() = 0;  // call after successful batch processing - clears any pending state
+    // call if batch processing fails - restores the cache state
+    virtual void restore() = 0;
 
-    virtual int32_t get_n_tokens()   const = 0;
-    virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
+    // call after successful batch processing - clears any pending state
+    virtual void commit()  = 0;
 
-    virtual bool get_can_shift() const = 0;
+    // process any pending defrag/shift/etc. operations
+    // optionally call once before processing a new batch
+    virtual bool update(llama_context & lctx) = 0;
+
+    // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
+    virtual void defrag_sched(float thold) = 0;
+
+    // simulate full cache, used for allocating worst-case compute buffers
+    virtual void set_full() = 0;
+
+    //
+    // batch processing
+    //
+
+    virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
+
+    // different KV caches require different batch splitting strategies
+    virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
+
+    // find an empty slot of size "n_tokens" in the cache
+    virtual bool find_slot(const llama_ubatch & batch) = 0;
+
+    // getters
+    virtual int32_t   get_n_tokens()   const = 0;
+    virtual int32_t   get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
+    virtual llama_pos get_pos_max()    const = 0;
+    virtual bool      get_can_shift()  const = 0;
 
     bool get_can_edit() const override { return get_can_shift(); }
+
+    //
+    // state write/read
+    //
+
+    virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
+    virtual void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) = 0;
 };
 
+//
+// llama_kv_cache_guard
+//
+
 struct llama_kv_cache_guard {
     llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
 
@@ -43,65 +83,50 @@ private:
     llama_kv_cache * kv;
 };
 
-struct llama_kv_cell {
-    llama_pos pos   = -1;
-    llama_pos delta =  0;
-    int32_t   src   = -1; // used by recurrent state models to copy states
-    int32_t   tail  = -1;
+//
+// llama_kv_cache_unified
+//
 
-    std::set<llama_seq_id> seq_id;
+// TODO: add notion of max sequences
+class llama_kv_cache_unified : public llama_kv_cache {
+public:
+    struct kv_cell {
+        llama_pos pos   = -1;
+        llama_pos delta =  0;
 
-    bool has_seq_id(const llama_seq_id & id) const {
-        return seq_id.find(id) != seq_id.end();
-    }
+        std::set<llama_seq_id> seq_id;
 
-    bool is_empty() const {
-        return seq_id.empty();
-    }
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
 
-    bool is_same_seq(const llama_kv_cell & other) const {
-        return seq_id == other.seq_id;
-    }
-};
+        bool is_empty() const {
+            return seq_id.empty();
+        }
 
-// ring-buffer of cached KV data
-// TODO: pimpl
-// TODO: add notion of max sequences
-class llama_kv_cache_unified : public llama_kv_cache {
-public:
-    // can be used to query data from the model if needed
-    struct callbacks {
-        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
     };
 
-    llama_kv_cache_unified(
-            const llama_hparams & hparams,
-            callbacks             cbs);
-
-    virtual ~llama_kv_cache_unified() = default;
+    static uint32_t get_padding(const llama_cparams & cparams);
 
-    // TODO: become constructor
-    bool init(
-            const llama_model & model,   // TODO: do not reference the model
-          const llama_cparams & cparams,
+    llama_kv_cache_unified(
+            const llama_model & model,
                     ggml_type   type_k,
                     ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
                      uint32_t   kv_size,
-                         bool   offload);
-
-    int32_t get_n_tokens()   const override;
-    int32_t get_used_cells() const override;
+                     uint32_t   padding);
 
-    size_t total_size() const;
+    ~llama_kv_cache_unified() = default;
 
-    // TODO: better data structures to reduce the cost of this operation
-    llama_pos pos_max() const;
+    //
+    // llama_memory_i
+    //
 
     void clear() override;
-    void defrag() override;
-
-    virtual void restore() override;
-    virtual void commit() override;
 
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@@ -111,25 +136,76 @@ public:
 
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
-    bool get_can_shift() const override;
+    //
+    // llama_kv_cache
+    //
+
+    void restore() override;
+    void commit()  override;
+
+    bool update(llama_context & ctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
-    // find an empty slot of size "n_tokens" in the cache
     // updates the cache head
     // Note: On success, it's important that cache.head points
     // to the first cell of the slot.
-    bool find_slot(const llama_ubatch & batch);
+    bool find_slot(const llama_ubatch & batch) override;
 
-    // TODO: maybe not needed
-    uint32_t get_padding(const llama_cparams & cparams) const;
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
 
-    // find how many cells are currently in use
-    uint32_t cell_max() const;
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos get_pos_max() const override;
 
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
+    bool get_can_shift() const override;
 
-    // defrag
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_impl also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
+    uint32_t head = 0;
+    uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<kv_cell> cells;
+
+    std::vector<ggml_tensor *> k_l; // per layer
+    std::vector<ggml_tensor *> v_l;
+
+private:
+    const llama_model & model;
+    const llama_hparams & hparams;
+
+    bool has_shift = false;
+    bool do_defrag = false;
+
+    bool v_trans   = true;  // the value tensor is transposed
+    bool can_shift = false;
+
+    // required padding
+    uint32_t padding = 1;
+
+    ggml_type type_k = GGML_TYPE_F16;
+    ggml_type type_v = GGML_TYPE_F16;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
 
+    // defrag
     struct {
         std::vector<uint32_t> ids;
     } defrag_info;
@@ -138,7 +214,6 @@ public:
     bool defrag_prepare(int32_t n_max_nodes);
 
     // commit/restore cache
-
     struct slot_range {
         uint32_t c0 = 0; // note: these are cell indices, not sequence positions
         uint32_t c1 = 0;
@@ -149,25 +224,124 @@ public:
         std::vector<slot_range> ranges;
     } pending;
 
-    // state write/load
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
 
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1);
+    size_t total_size() const;
 
-    // members
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
 
-    const llama_hparams & hparams;
+    ggml_tensor * build_rope_shift(
+            const llama_cparams & cparams,
+                   ggml_context * ctx,
+                    ggml_tensor * cur,
+                    ggml_tensor * shift,
+                    ggml_tensor * factors,
+                          float   freq_base,
+                          float   freq_scale) const;
+
+    llm_graph_result_ptr build_graph_shift(
+            const llama_cparams & cparams,
+                   ggml_context * ctx,
+                    ggml_cgraph * gf) const;
+
+    llm_graph_result_ptr build_graph_defrag(
+            const llama_cparams & cparams,
+                   ggml_context * ctx,
+                    ggml_cgraph * gf) const;
 
-    callbacks cbs;
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
 
-    bool has_shift = false;
-    bool do_defrag = false;
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
 
-    // TODO: remove this and implement llama_kv_cache_recurrent instead
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
+//
+// llama_kv_cache_recurrent
+//
 
-    bool v_trans   = true;  // the value tensor is transposed
-    bool can_shift = false;
+class llama_kv_cache_recurrent : public llama_kv_cache {
+public:
+    struct kv_cell {
+        llama_pos pos  = -1;
+        int32_t   src  = -1; // used to copy states
+        int32_t   tail = -1;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
+    llama_kv_cache_recurrent(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   offload,
+                     uint32_t   kv_size);
+
+    ~llama_kv_cache_recurrent() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    void restore() override;
+    void commit()  override;
+
+    bool update(llama_context & lctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+
+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
+    bool find_slot(const llama_ubatch & batch) override;
+
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
+
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos get_pos_max() const override;
+
+    bool get_can_shift() const override;
+
+    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
+    int32_t s_copy(int i) const;
+    float   s_mask(int i) const;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
 
     // Note: The value of head isn't only used to optimize searching
     // for a free KV slot. llama_decode_impl also uses it, so it
@@ -179,18 +353,41 @@ public:
     // computed before each graph build
     uint32_t n = 0;
 
-    std::vector<llama_kv_cell> cells;
+    std::vector<kv_cell> cells;
 
     std::vector<ggml_tensor *> k_l; // per layer
     std::vector<ggml_tensor *> v_l;
 
 private:
+    //const llama_model & model;
+    const llama_hparams & hparams;
+
+    // commit/restore cache
+    // TODO: rework for recurrent cache
+    struct slot_range {
+        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
+        uint32_t c1 = 0;
+    };
+
+    // pending cell updates that are not yet committed
+    struct {
+        std::vector<slot_range> ranges;
+    } pending;
+
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
+    // find how many cells are currently in use
+    uint32_t cell_max() const;
+
+    size_t total_size() const;
+
+    size_t size_k_bytes() const;
+    size_t size_v_bytes() const;
+
     void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
 
@@ -198,11 +395,6 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
-// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
-//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
-//public:
-//    using llama_kv_cache_unified::llama_kv_cache_unified;
-//};
 
 //
 // kv cache view
index dfa8c4e90fc2a90c61c9da74bc84a9d7eafb3fc5..c7412d5911ed79153a5d8fefe7eb24917901630c 100644 (file)
@@ -2,12 +2,22 @@
 
 #include "llama.h"
 
+struct llama_memory_params {
+    // kv cache
+    ggml_type type_k;
+    ggml_type type_v;
+
+    // parameters for other types of memory
+    // ...
+};
+
 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types
 class llama_memory_i {
 public:
+    virtual ~llama_memory_i() = default;
+
     virtual void clear() = 0;
-    virtual void defrag() = 0;
 
     virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
     virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
index ea73a8a7ba944898f1003fe1ed3fc5ecca53a181..4cce51668b42d09ed155ae421d6a2a618f403d90 100644 (file)
@@ -301,12 +301,12 @@ namespace GGUFMeta {
             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
 
         switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
+                                                (std::is_same<T, uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
         }
 
         result.resize(arr_info.length);
@@ -330,12 +330,12 @@ namespace GGUFMeta {
             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid);
 
         switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same<T,  int32_t>::value) ||
-                                            (std::is_same<T, uint32_t>::value));  break;
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) ||
+                                                (std::is_same<T, uint32_t>::value)); break;
+            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break;
             default:
-                throw std::runtime_error(format("%s is not a float32int32 array", key.c_str()));
+                throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str()));
         }
 
         if (arr_info.length > N_MAX) {
@@ -823,6 +823,10 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps
         mmaps_used.reserve(files.size());
         for (const auto & file : files) {
             auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
+            if (!reg) {
+                throw std::runtime_error(format("%s: no CPU backend found", __func__));
+            }
+
             auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
             std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file.get(), prefetch ? -1 : 0, is_numa_fn());
             mmaps_used.emplace_back(mapping->size(), 0);
diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp
new file mode 100644 (file)
index 0000000..a70b989
--- /dev/null
@@ -0,0 +1,281 @@
+#include "llama-model-saver.h"
+
+#include "gguf.h"
+
+#include "llama.h"
+#include "llama-hparams.h"
+#include "llama-model.h"
+#include "llama-vocab.h"
+
+#include <string>
+
+llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) {
+    gguf_ctx = gguf_init_empty();
+}
+
+llama_model_saver::~llama_model_saver() {
+    gguf_free(gguf_ctx);
+}
+
+void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) {
+    gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value);
+}
+
+void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) {
+    gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value);
+}
+
+void llama_model_saver::add_kv(const enum llm_kv key, const float value) {
+    gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value);
+}
+
+void llama_model_saver::add_kv(const enum llm_kv key, const bool value) {
+    gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value);
+}
+
+void llama_model_saver::add_kv(const enum llm_kv key, const char * value) {
+    gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value);
+}
+
+[[noreturn]]
+void llama_model_saver::add_kv(const enum llm_kv key, const char value) {
+    GGML_UNUSED(key);
+    GGML_UNUSED(value);
+    GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile
+}
+
+template <typename Container>
+void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) {
+    const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size();
+    GGML_ASSERT(n_values <= value.size());
+
+    if (n_values == 0) {
+        return;
+    }
+
+    if (per_layer) {
+        bool all_values_the_same = true;
+        for (size_t i = 1; i < n_values; ++i) {
+            if (value[i] != value[0]) {
+                all_values_the_same = false;
+                break;
+            }
+        }
+        if (all_values_the_same) {
+            add_kv(key, value[0]);
+            return;
+        }
+    }
+
+    if (std::is_same<typename Container::value_type, uint8_t>::value) {
+        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values);
+    } else if (std::is_same<typename Container::value_type, int8_t>::value) {
+        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values);
+    } else if (std::is_same<typename Container::value_type, uint32_t>::value) {
+        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values);
+    } else if (std::is_same<typename Container::value_type, int32_t>::value) {
+        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values);
+    } else if (std::is_same<typename Container::value_type, float>::value) {
+        gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values);
+    } else if (std::is_same<Container, std::string>::value) {
+        gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast<const char *>(value.data()));
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
+void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) {
+    std::vector<const char *> tmp(value.size());
+    for (size_t i = 0; i < value.size(); ++i) {
+        tmp[i] = value[i].c_str();
+    }
+    gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size());
+}
+
+void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) {
+    if (!tensor) {
+        return;
+    }
+    if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) {
+        GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME
+        return;
+    }
+    gguf_add_tensor(gguf_ctx, tensor);
+}
+
+void llama_model_saver::add_kv_from_model() {
+    const llama_hparams & hparams = model.hparams;
+    const llama_vocab   & vocab   = model.vocab;
+
+    const int32_t n_vocab = vocab.n_tokens();
+    std::vector<std::string> tokens(n_vocab);
+    std::vector<float>       scores(n_vocab);
+    std::vector<int32_t>     token_types(n_vocab);
+
+    for (int32_t id = 0; id < n_vocab; ++id) {
+        const llama_vocab::token_data & token_data = vocab.get_token_data(id);
+
+        tokens[id] = token_data.text;
+        scores[id] = token_data.score;
+
+        switch(token_data.attr) {
+            case LLAMA_TOKEN_ATTR_UNKNOWN:      token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN;      break;
+            case LLAMA_TOKEN_ATTR_UNUSED:       token_types[id] = LLAMA_TOKEN_TYPE_UNUSED;       break;
+            case LLAMA_TOKEN_ATTR_NORMAL:       token_types[id] = LLAMA_TOKEN_TYPE_NORMAL;       break;
+            case LLAMA_TOKEN_ATTR_CONTROL:      token_types[id] = LLAMA_TOKEN_TYPE_CONTROL;      break;
+            case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break;
+            case LLAMA_TOKEN_ATTR_BYTE:         token_types[id] = LLAMA_TOKEN_TYPE_BYTE;         break;
+            case LLAMA_TOKEN_ATTR_UNDEFINED:
+            default:                            token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED;    break;
+        }
+    }
+
+    // add_kv(LLM_KV_GENERAL_TYPE,                      ???);
+    add_kv(LLM_KV_GENERAL_ARCHITECTURE,              model.arch_name());
+    // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION,      ???);
+    // add_kv(LLM_KV_GENERAL_ALIGNMENT,                 ???);
+    add_kv(LLM_KV_GENERAL_NAME,                      model.name);
+    // add_kv(LLM_KV_GENERAL_AUTHOR,                    ???);
+    // add_kv(LLM_KV_GENERAL_VERSION,                   ???);
+    // add_kv(LLM_KV_GENERAL_URL,                       ???);
+    // add_kv(LLM_KV_GENERAL_DESCRIPTION,               ???);
+    // add_kv(LLM_KV_GENERAL_LICENSE,                   ???);
+    // add_kv(LLM_KV_GENERAL_SOURCE_URL,                ???);
+    // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO,            ???);
+
+    add_kv(LLM_KV_VOCAB_SIZE,                        vocab.n_tokens());
+    add_kv(LLM_KV_CONTEXT_LENGTH,                    hparams.n_ctx_train);
+    add_kv(LLM_KV_EMBEDDING_LENGTH,                  hparams.n_embd);
+    add_kv(LLM_KV_BLOCK_COUNT,                       hparams.n_layer);
+    add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);
+    add_kv(LLM_KV_FEED_FORWARD_LENGTH,               hparams.n_ff_arr, true);
+    add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+    add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+    add_kv(LLM_KV_USE_PARALLEL_RESIDUAL,             hparams.use_par_res);
+    // add_kv(LLM_KV_TENSOR_DATA_LAYOUT,                ???);
+    add_kv(LLM_KV_EXPERT_COUNT,                      hparams.n_expert);
+    add_kv(LLM_KV_EXPERT_USED_COUNT,                 hparams.n_expert_used);
+    add_kv(LLM_KV_EXPERT_SHARED_COUNT,               hparams.n_expert_shared);
+    add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE,              hparams.expert_weights_scale);
+    add_kv(LLM_KV_POOLING_TYPE,                      uint32_t(hparams.pooling_type));
+    add_kv(LLM_KV_LOGIT_SCALE,                       hparams.f_logit_scale);
+    add_kv(LLM_KV_DECODER_START_TOKEN_ID,            hparams.dec_start_token_id);
+    add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING,            hparams.f_attn_logit_softcapping);
+    add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING,           hparams.f_final_logit_softcapping);
+    add_kv(LLM_KV_SWIN_NORM,                         hparams.swin_norm);
+    add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS,            hparams.rescale_every_n_layers);
+    add_kv(LLM_KV_TIME_MIX_EXTRA_DIM,                hparams.time_mix_extra_dim);
+    add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM,              hparams.time_decay_extra_dim);
+    add_kv(LLM_KV_RESIDUAL_SCALE,                    hparams.f_residual_scale);
+    add_kv(LLM_KV_EMBEDDING_SCALE,                   hparams.f_embedding_scale);
+
+    add_kv(LLM_KV_ATTENTION_HEAD_COUNT,              hparams.n_head_arr, true);
+    add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV,           hparams.n_head_kv_arr, true);
+    add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS,          hparams.f_max_alibi_bias);
+    add_kv(LLM_KV_ATTENTION_CLAMP_KQV,               hparams.f_clamp_kqv);
+    add_kv(LLM_KV_ATTENTION_KEY_LENGTH,              hparams.n_embd_head_k);
+    add_kv(LLM_KV_ATTENTION_VALUE_LENGTH,            hparams.n_embd_head_v);
+    add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS,           hparams.f_norm_eps);
+    add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+    add_kv(LLM_KV_ATTENTION_CAUSAL,                  hparams.causal_attn);
+    add_kv(LLM_KV_ATTENTION_Q_LORA_RANK,             hparams.n_lora_q);
+    add_kv(LLM_KV_ATTENTION_KV_LORA_RANK,            hparams.n_lora_kv);
+    add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,  hparams.n_rel_attn_bkts);
+    add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW,          hparams.n_swa);
+    add_kv(LLM_KV_ATTENTION_SCALE,                   hparams.f_attention_scale);
+
+    const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train;
+
+    add_kv(LLM_KV_ROPE_DIMENSION_COUNT,              hparams.n_rot);
+    add_kv(LLM_KV_ROPE_FREQ_BASE,                    hparams.rope_freq_base_train);
+    // add_kv(LLM_KV_ROPE_SCALE_LINEAR,                 rope_scaling_factor); // old name
+    add_kv(LLM_KV_ROPE_SCALING_TYPE,                 llama_rope_scaling_type_name(hparams.rope_scaling_type_train));
+    add_kv(LLM_KV_ROPE_SCALING_FACTOR,               rope_scaling_factor);
+    add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR,          hparams.rope_attn_factor);
+    add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,         hparams.n_ctx_orig_yarn);
+    add_kv(LLM_KV_ROPE_SCALING_FINETUNED,            hparams.rope_finetuned);
+    add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,         hparams.rope_yarn_log_mul);
+
+    // TODO: implement split file support
+    // add_kv(LLM_KV_SPLIT_NO,                          ???);
+    // add_kv(LLM_KV_SPLIT_COUNT,                       ???);
+    // add_kv(LLM_KV_SPLIT_TENSORS_COUNT,               ???);
+
+    add_kv(LLM_KV_SSM_INNER_SIZE,                    hparams.ssm_d_inner);
+    add_kv(LLM_KV_SSM_CONV_KERNEL,                   hparams.ssm_d_conv);
+    add_kv(LLM_KV_SSM_STATE_SIZE,                    hparams.ssm_d_state);
+    add_kv(LLM_KV_SSM_TIME_STEP_RANK,                hparams.ssm_dt_rank);
+    add_kv(LLM_KV_SSM_DT_B_C_RMS,                    hparams.ssm_dt_b_c_rms);
+
+    add_kv(LLM_KV_WKV_HEAD_SIZE,                     hparams.wkv_head_size);
+
+    add_kv(LLM_KV_TOKENIZER_MODEL,                   vocab.get_tokenizer_model());
+    add_kv(LLM_KV_TOKENIZER_PRE,                     vocab.get_tokenizer_pre());
+    add_kv(LLM_KV_TOKENIZER_LIST,                    tokens);
+    add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE,              token_types);
+    add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,        vocab.n_token_types());
+    add_kv(LLM_KV_TOKENIZER_SCORES,                  scores);
+    add_kv(LLM_KV_TOKENIZER_MERGES,                  vocab.get_bpe_merges());
+    // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though
+    add_kv(LLM_KV_TOKENIZER_BOS_ID,                  uint32_t(vocab.token_bos()));
+    add_kv(LLM_KV_TOKENIZER_EOS_ID,                  uint32_t(vocab.token_eos()));
+    add_kv(LLM_KV_TOKENIZER_EOT_ID,                  uint32_t(vocab.token_eot()));
+    add_kv(LLM_KV_TOKENIZER_EOM_ID,                  uint32_t(vocab.token_eom()));
+    add_kv(LLM_KV_TOKENIZER_UNK_ID,                  uint32_t(vocab.token_unk()));
+    add_kv(LLM_KV_TOKENIZER_SEP_ID,                  uint32_t(vocab.token_sep()));
+    add_kv(LLM_KV_TOKENIZER_PAD_ID,                  uint32_t(vocab.token_pad()));
+    // add_kv(LLM_KV_TOKENIZER_CLS_ID,                  uint32_t(vocab.token_bos())); // deprecated
+    // add_kv(LLM_KV_TOKENIZER_MASK_ID,                 ???);
+    add_kv(LLM_KV_TOKENIZER_ADD_BOS,                 vocab.get_add_bos());
+    add_kv(LLM_KV_TOKENIZER_ADD_EOS,                 vocab.get_add_eos());
+    add_kv(LLM_KV_TOKENIZER_ADD_PREFIX,              vocab.get_add_space_prefix());
+    add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,         vocab.get_remove_extra_whitespaces());
+    add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,    vocab.get_precompiled_charsmap());
+    // add_kv(LLM_KV_TOKENIZER_HF_JSON,                 ???);
+    // add_kv(LLM_KV_TOKENIZER_RWKV,                    ???);
+    add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID,              uint32_t(vocab.token_fim_pre()));
+    add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID,              uint32_t(vocab.token_fim_suf()));
+    add_kv(LLM_KV_TOKENIZER_FIM_MID_ID,              uint32_t(vocab.token_fim_mid()));
+    add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID,              uint32_t(vocab.token_fim_pad()));
+    add_kv(LLM_KV_TOKENIZER_FIM_REP_ID,              uint32_t(vocab.token_fim_rep()));
+    add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID,              uint32_t(vocab.token_fim_sep()));
+
+    // TODO: implement LoRA support
+    // add_kv(LLM_KV_ADAPTER_TYPE,                      ???);
+    // add_kv(LLM_KV_ADAPTER_LORA_ALPHA,                ???);
+
+    // deprecated
+    // add_kv(LLM_KV_TOKENIZER_PREFIX_ID,               ???);
+    // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID,               ???);
+    // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID,               ???);
+}
+
+void llama_model_saver::add_tensors_from_model() {
+    if (std::string(model.output->name) != std::string(model.tok_embd->name)) {
+        add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output
+    }
+    add_tensor(model.type_embd);
+    add_tensor(model.pos_embd);
+    add_tensor(model.tok_norm);
+    add_tensor(model.tok_norm_b);
+    add_tensor(model.output_norm);
+    add_tensor(model.output_norm_b);
+    add_tensor(model.output);
+    add_tensor(model.output_b);
+    add_tensor(model.output_norm_enc);
+    add_tensor(model.cls);
+    add_tensor(model.cls_b);
+    add_tensor(model.cls_out);
+    add_tensor(model.cls_out_b);
+
+    for (const struct llama_layer & layer : model.layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+            add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]);
+        }
+    }
+}
+
+void llama_model_saver::save(const std::string & path_model) {
+    gguf_write_to_file(gguf_ctx, path_model.c_str(), false);
+}
+
diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h
new file mode 100644 (file)
index 0000000..a5a434c
--- /dev/null
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "llama.h"
+#include "llama-arch.h"
+
+#include <vector>
+
+struct llama_model_saver {
+    struct gguf_context * gguf_ctx = nullptr;
+    const struct llama_model & model;
+    const struct LLM_KV llm_kv;
+
+    llama_model_saver(const struct llama_model & model);
+    ~llama_model_saver();
+
+    void add_kv(enum llm_kv key, uint32_t     value);
+    void add_kv(enum llm_kv key, int32_t      value);
+    void add_kv(enum llm_kv key, float        value);
+    void add_kv(enum llm_kv key, bool         value);
+    void add_kv(enum llm_kv key, const char * value);
+
+    [[noreturn]]
+    void add_kv(enum llm_kv key, char value); // needed to make the template below compile
+
+    template <typename Container>
+    void add_kv(enum llm_kv key, const Container & value, bool per_layer = false);
+
+    void add_kv(enum llm_kv key, const std::vector<std::string> & value);
+
+    void add_tensor(const struct ggml_tensor * tensor);
+
+    void add_kv_from_model();
+
+    void add_tensors_from_model();
+
+    void save(const std::string & path_model);
+};
index 51092a128c5c6e9f216b5ccd3c1d9d555905c671..3a4e72a36b0730417d8dd670b706053aa3f01b88 100644 (file)
@@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_236B:          return "236B";
         case LLM_TYPE_290B:          return "290B";
         case LLM_TYPE_314B:          return "314B";
+        case LLM_TYPE_405B:          return "405B";
         case LLM_TYPE_671B:          return "671B";
         case LLM_TYPE_SMALL:         return "0.1B";
         case LLM_TYPE_MEDIUM:        return "0.4B";
@@ -116,6 +117,10 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
     { LLAMA_ROPE_SCALING_TYPE_LONGROPE,   "longrope"   },
 };
 
+std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) {
+    return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type);
+}
+
 static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
@@ -298,6 +303,10 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
     // add extra buffer types, only if no GPU device is present
     // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
     auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    if (cpu_dev == nullptr) {
+        throw std::runtime_error(format("%s: no CPU backend found", __func__));
+    }
+
     auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
     auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
         ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
@@ -582,6 +591,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 switch (hparams.n_layer) {
                     case 32: type = LLM_TYPE_7B; break;
                     case 80: type = LLM_TYPE_70B; break;
+                    case 162: type = LLM_TYPE_405B; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
@@ -773,6 +783,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             // fall through
         case LLM_ARCH_QWEN2:
             {
+                ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break;
@@ -1481,6 +1492,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     }
 
     ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    if (cpu_dev == nullptr) {
+        throw std::runtime_error(format("%s: no CPU backend found", __func__));
+    }
     const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
     const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
     auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
@@ -1648,8 +1662,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
                     std::regex pattern(overrides->pattern);
                     if (std::regex_search(tensor_name, pattern)) {
-                        LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
                         buft = overrides->buft;
+                        LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n",
+                                tensor_name.c_str(),
+                                ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type),
+                                ggml_backend_buft_name(buft));
                         break;
                     }
                 }
@@ -1666,6 +1683,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             auto * buft_dev = ggml_backend_buft_get_device(buft);
             if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
                 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                if (!cpu_dev) {
+                    throw std::runtime_error("no CPU backend found");
+                }
                 buft = ggml_backend_dev_buffer_type(cpu_dev);
             }
 
@@ -1847,7 +1867,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
                         layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        if (n_ff > 0) {
+                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        }
 
                         if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
                             layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
@@ -1857,9 +1879,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                         }
 
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        if (n_ff > 0) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        }
 
                         // optional MLP bias
                         layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
@@ -3503,7 +3527,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -4108,6 +4136,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         if (!dev) {
             // FIXME: workaround for CPU backend buft having a NULL device
             dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            if (!dev) {
+                throw std::runtime_error(format("%s: no CPU backend found", __func__));
+            }
         }
         ggml_backend_dev_props props;
         ggml_backend_dev_get_props(dev, &props);
@@ -4237,7 +4268,7 @@ uint64_t llama_model::n_elements() const {
 }
 
 void llama_model::print_info() const {
-    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+    const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train);
 
     auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
         bool is_var = false;
@@ -4298,7 +4329,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
         LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
         LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
-        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
+        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type.c_str());
         LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
         LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
         LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
@@ -4445,6 +4476,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
+ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+    // choose long/short freq factors based on the context size
+    if (layers[il].rope_freqs != nullptr) {
+        return layers[il].rope_freqs;
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+        return layers[il].rope_long;
+    }
+
+    return layers[il].rope_short;
+}
+
 struct llm_build_llama : public llm_graph_context {
     llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -4485,7 +4529,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4691,6 +4735,7 @@ struct llm_build_deci : public llm_graph_context {
             ggml_tensor * inpSA = inpL;
             const int64_t n_head_kv = hparams.n_head_kv(il);
             const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_ff      = hparams.n_ff(il);
 
             if (n_head == 0) {
                 // attention-free layer of Llama-3_1-Nemotron-51B
@@ -4710,7 +4755,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4766,6 +4811,11 @@ struct llm_build_deci : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
+            // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B
+            if (n_ff == 0) {
+                continue;
+            }
+
             // For Granite architecture
             if (hparams.f_residual_scale) {
                 cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
@@ -7192,7 +7242,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 ggml_tensor* attn_norm_output = build_norm(inpL,
                         model.layers[il].attn_norm,
@@ -7944,7 +7994,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
             // norm
             cur = build_norm(inpL,
@@ -8711,7 +8761,7 @@ struct llm_build_mamba : public llm_graph_context {
              ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
-        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
         const auto kv_head = kv_self->head;
 
@@ -9012,7 +9062,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9950,7 +10000,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11314,7 +11364,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11459,7 +11509,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int   il) const {
-        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -11855,7 +11905,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+        const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12695,7 +12745,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12815,36 +12865,46 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
 
-llama_memory_i * llama_model::create_memory() const {
+llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
     switch (arch) {
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_NOMIC_BERT_MOE:
+            {
+                res = nullptr;
+            } break;
         case LLM_ARCH_MAMBA:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_RWKV7:
         case LLM_ARCH_ARWKV7:
             {
-                res = new llama_kv_cache_unified(hparams, {
-                    /*.get_rope_factors =*/ nullptr
-                });
+                res = new llama_kv_cache_recurrent(
+                        *this,
+                        GGML_TYPE_F32,
+                        GGML_TYPE_F32,
+                        cparams.offload_kqv,
+                        std::max((uint32_t) 1, cparams.n_seq_max));
             } break;
         default:
             {
-                res = new llama_kv_cache_unified(hparams, {
-                    /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
-                        // choose long/short freq factors based on the context size
-                        if (layers[il].rope_freqs != nullptr) {
-                            return layers[il].rope_freqs;
-                        }
+                const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-                        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-                            return layers[il].rope_long;
-                        }
+                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
-                        return layers[il].rope_short;
-                    }
-                });
+                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                res = new llama_kv_cache_unified(
+                        *this,
+                        params.type_k,
+                        params.type_v,
+                        !cparams.flash_attn,
+                        cparams.offload_kqv,
+                        cparams.n_ctx,
+                        padding);
             }
     }
 
@@ -13226,8 +13286,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
         case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_ORION:
         case LLM_ARCH_INTERNLM2:
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_XVERSE:
@@ -13265,6 +13323,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
+        case LLM_ARCH_PLAMO:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_GEMMA3:
@@ -13272,6 +13331,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_OPENELM:
         case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_CODESHELL:
+        case LLM_ARCH_ORION:
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_MINICPM3:
@@ -13344,6 +13404,14 @@ const char * llama_model_chat_template(const llama_model * model, const char * n
         : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
     const auto & it = model->gguf_kv.find(key);
     if (it == model->gguf_kv.end()) {
+        // one-off fix for very popular models (so we are not flooded with issues)
+        // do not extend this list unless absolutely necessary
+        // Mistral-Small-2503 does not have built-in chat template
+        llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
+        if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
+            return "mistral-v7-tekken";
+        }
+
         return nullptr;
     }
 
index 34aac337cff2769c085e5a9157a7cf58b76e7b83..6bdec263b709b2b027db73799aaa71b5f7326225 100644 (file)
@@ -76,6 +76,7 @@ enum llm_type {
     LLM_TYPE_236B,
     LLM_TYPE_290B,
     LLM_TYPE_314B,
+    LLM_TYPE_405B,
     LLM_TYPE_671B,
     LLM_TYPE_SMALL,
     LLM_TYPE_MEDIUM,
@@ -95,6 +96,8 @@ enum llm_type {
     LLM_TYPE_235B_A22B,
 };
 
+std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
+
 struct llama_layer_posnet {
     // resnet
     struct ggml_tensor * norm1   = nullptr;
@@ -395,8 +398,11 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
+    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+
+    // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
-    llama_memory_i * create_memory() const; // TODO: params
+    llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
 
     // TODO: move this to new llm_arch_model_i interface
     llm_graph_result_ptr build_graph(
index 7dc54227631180c1d4ee03f77631eddddca3f851..820d5128e29ba700702ef31447be9ddeb7c52d53 100644 (file)
@@ -519,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         nthread = std::thread::hardware_concurrency();
     }
 
-    // mmap consistently increases speed Linux, and also increases speed on Windows with
+    // mmap consistently increases speed on Linux, and also increases speed on Windows with
     // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
 #if defined(__linux__) || defined(_WIN32)
     constexpr bool use_mmap = true;
@@ -529,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
     llama_model_kv_override * kv_overrides = nullptr;
     if (params->kv_overrides) {
-        auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
+        auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
 
index c0a5f9340d5851beade2deb00ef41bea2e38bc7e..804b11e0a943e9625c78516c5da629ec91261968 100644 (file)
@@ -1750,23 +1750,35 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler *
 static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
 
+    if (ctx->n <= 0.0f || cur_p->size <= 1) {
+        return;
+    }
+
     // find max logit and calculate mean
     float max = cur_p->data[0].logit;
     float logits_sum = 0;
+    size_t valid_count = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (cur_p->data[i].logit > max) {
-            max = cur_p->data[i].logit;
+        // Only count non-negative infinity values
+        if (cur_p->data[i].logit != -INFINITY) {
+            if (cur_p->data[i].logit > max) {
+                max = cur_p->data[i].logit;
+            }
+            logits_sum += cur_p->data[i].logit;
+            valid_count++;
         }
-        logits_sum += cur_p->data[i].logit;
     }
-    float mean = logits_sum/cur_p->size;
+    float mean = valid_count > 0 ? logits_sum/valid_count : 0;
 
     // calculate standard deviation
     float acc = 0;
     for (size_t i = 0; i < cur_p->size; ++i) {
-        acc += pow(cur_p->data[i].logit - mean, 2);
+        // Skip -infinity in std calculation
+        if (cur_p->data[i].logit != -INFINITY) {
+            acc += pow(cur_p->data[i].logit - mean, 2);
+        }
     }
-    float std = sqrt(acc/cur_p->size);
+    float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
 
     //apply mask
     for (size_t i = 0; i < cur_p->size; ++i) {
index 50ded286f3f5f989afec68e21d2bf81d567839db..9389ca805a584fadcb9110b0841f95234907b8b7 100644 (file)
@@ -1,5 +1,7 @@
 #include "llama-vocab.h"
 
+#include "ggml.h"
+#include "gguf.h"
 #include "llama-impl.h"
 #include "llama-model-loader.h"
 
@@ -415,6 +417,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_SEED_CODER:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1227,6 +1236,9 @@ struct fragment_buffer_variant {
 struct llama_vocab::impl {
     uint32_t n_token_types = 0; // for BERT-style token types
 
+    std::string tokenizer_model;
+    std::string tokenizer_pre;
+
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
@@ -1362,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
     // determine vocab type
     {
-        std::string tokenizer_model;
-        std::string tokenizer_pre;
-
         ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
         ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
 
@@ -1459,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
             const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
             if (precompiled_charsmap_keyidx != -1) {
-                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+                const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx);
+                GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8);
+
+                const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
                 const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
                 precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
 #ifdef IS_BIG_ENDIAN
@@ -1634,6 +1646,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "bailingmoe") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "seed-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -2778,6 +2794,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
     pimpl->load(ml, kv);
 }
 
+std::string llama_vocab::get_tokenizer_model() const {
+    return pimpl->tokenizer_model;
+}
+
+std::string llama_vocab::get_tokenizer_pre() const {
+    return pimpl->tokenizer_pre;
+}
+
 enum llama_vocab_type llama_vocab::get_type() const {
     return pimpl->type;
 }
@@ -3000,6 +3024,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
     return it->second;
 }
 
+std::vector<std::string> llama_vocab::get_bpe_merges() const {
+    std::vector<std::string> result(pimpl->bpe_ranks.size());
+
+    for (const auto & pair : pimpl->bpe_ranks) {
+        result[pair.second] = pair.first.first + " " + pair.first.second;
+    }
+
+    return result;
+}
+
+std::vector<char> llama_vocab::get_precompiled_charsmap() const {
+    return pimpl->precompiled_charsmap;
+}
+
 int32_t llama_vocab::tokenize(
                   const char * text,
                      int32_t   text_len,
index 5ce355214346f12313c40f689ebfe3aa4b8d1a12..daa6cf3082f90a3dc1ace5fff4b379bd3220c51e 100644 (file)
@@ -21,6 +21,9 @@ struct llama_vocab {
 
     void load(llama_model_loader & ml, const LLM_KV & kv);
 
+    std::string get_tokenizer_model() const;
+    std::string get_tokenizer_pre() const;
+
     enum llama_vocab_type     get_type()     const;
     enum llama_vocab_pre_type get_pre_type() const;
 
@@ -80,6 +83,9 @@ struct llama_vocab {
     int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+    std::vector<std::string> get_bpe_merges() const;
+
+    std::vector<char> get_precompiled_charsmap() const;
 
     int32_t tokenize(
                    const char * text,
index c52bbe5f2804f918afb0f49726c54497a998d670..9fdddf7b071f83925ae7d4662046fa6340f91fd9 100644 (file)
@@ -4,6 +4,7 @@
 #include "llama-mmap.h"
 #include "llama-vocab.h"
 #include "llama-model-loader.h"
+#include "llama-model-saver.h"
 #include "llama-model.h"
 
 #include "ggml.h"
 #include <cstring>
 #include <ctime>
 
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
 //
 // interface implementation
 //
@@ -249,6 +254,13 @@ struct llama_model * llama_model_load_from_splits(
     return llama_model_load_from_file_impl(splits.front(), splits, params);
 }
 
+void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
+    llama_model_saver ms(*model);
+    ms.add_kv_from_model();
+    ms.add_tensors_from_model();
+    ms.save(path_model);
+}
+
 //
 // chat templates
 //
@@ -334,3 +346,4 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
+
index 06c56395c139fb8cc936542cae2a124c16cfba71..99e5fba244fcc2902fa9e97a085392820ee3975e 100644 (file)
@@ -4,6 +4,7 @@
 #include "ggml.h"
 #include "ggml-cpu.h"
 #include "ggml-backend.h"
+#include "ggml-opt.h"
 
 #include <stddef.h>
 #include <stdint.h>
@@ -112,6 +113,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
         LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
         LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
+        LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
     };
 
     enum llama_rope_type {
@@ -343,7 +345,7 @@ extern "C" {
         float    yarn_beta_fast;   // YaRN low correction dim
         float    yarn_beta_slow;   // YaRN high correction dim
         uint32_t yarn_orig_ctx;    // YaRN original context size
-        float    defrag_thold;     // defragment the KV cache if holes/size > thold, < 0 disabled (default)
+        float    defrag_thold;     // defragment the KV cache if holes/size > thold, <= 0 disabled (default)
 
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
@@ -351,19 +353,18 @@ extern "C" {
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
-        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
-        // TODO: move at the end of the struct
-        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-
         // Abort callback
         // if it returns true, execution of llama_decode() will be aborted
         // currently works only with CPU execution
         ggml_abort_callback abort_callback;
         void *              abort_callback_data;
+
+        // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+        bool embeddings;  // if true, extract embeddings (together with logits)
+        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+        bool no_perf;     // whether to measure performance timings
+        bool op_offload;  // whether to offload host tensor operations to device
     };
 
     // model quantization parameters
@@ -445,6 +446,10 @@ extern "C" {
                                  size_t    n_paths,
               struct llama_model_params    params);
 
+    LLAMA_API void llama_model_save_to_file(
+            const struct llama_model * model,
+                        const char * path_model);
+
     DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
             "use llama_model_free instead");
 
@@ -924,14 +929,19 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
-    // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
-    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    // Process a batch of tokens.
+    // In contrast to llama_decode() - this call does not use KV cache.
+    // For encode-decoder contexts, processes the batch using the encoder.
+    // Can store the encoder output internally for later use by the decoder's cross-attention layers.
     //   0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
               struct llama_batch   batch);
 
+    // Process a batch of tokens.
+    // Requires KV cache.
+    // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@@ -1428,6 +1438,37 @@ extern "C" {
     LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
     LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
 
+    //
+    // training
+    //
+
+    // function that returns whether or not a given tensor contains trainable parameters
+    typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
+
+    // always returns true
+    LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
+
+    struct llama_opt_params {
+        uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
+
+        llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
+        void * param_filter_ud;              // userdata for determining which tensors contain trainable parameters
+
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+    };
+
+    LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
+
+    LLAMA_API void llama_opt_epoch(
+            struct llama_context    * lctx,
+            ggml_opt_dataset_t        dataset,
+            ggml_opt_result_t         result_train,
+            ggml_opt_result_t         result_eval,
+            int64_t                   idata_split,
+            ggml_opt_epoch_callback   callback_train,
+            ggml_opt_epoch_callback   callback_eval);
+
 #ifdef __cplusplus
 }
 #endif