]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
memory : Hybrid recurrent cache (#13979)
authorGabe Goodhart <redacted>
Thu, 19 Jun 2025 05:08:14 +0000 (00:08 -0500)
committerGitHub <redacted>
Thu, 19 Jun 2025 05:08:14 +0000 (08:08 +0300)
* feat: Add llama_model_is_hybrid API call

Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybird. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* feat: Add c++ side constants for attention layer indices hparam

Branch: GraniteFour

* feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* refactor: rename *_is_hybrid -> *_is_hybrid_recurrent

The implementation of the hybrid cache intentionally does not specify the
types of the child caches, so there was a naming mismatch with these
predicate functions that used "hybrid" to imply "hybrid recurrent."

Branch: HybridCache

Signed-off-by: Gabe Goodhart <redacted>
* feat: Add layer filter to recurrent cache

Branch: HybridCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Use per-layer sizing everywhere in kv caches

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* feat: First pass at llama_kv_cache_hybrid_recurrent

This follows the pattern in iswa where the two child caches are held
explicitly to support the case where a model requires a single attention
cache and a single recurrent cache where each layer uses exactly one of the
caches.

This is a rewrite of the more generic approach in the original hybrid cache
PR: https://github.com/ggml-org/llama.cpp/pull/13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* feat: Construct hybrid recurrent cache for hybrid recurrent models

This includes a refactor of the create_memory logic to avoid needing to use
the arch enum explicitly unless a model needs explicit cache instantiation
logic beyond the standard logic for recurrent, hybrid, unified, and iswa.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix wrong bool condition for split equal in hybrid cache

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix shift logic to defer to unified cache

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* feat: Support hybrid recurrent in llama-graph

NOTE: I intentionally did not add support for s_mask since it will be going
away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix logic for initializing inputs and attn layers for hybrid caches

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* fix: Update recurrent cache for changes to remove intermediate kv_cache interface

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix status for init_update sig for recurrent cache state

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* fix: Add missing padding to n_ctx for hybrid cache construction

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <redacted>
* fix: Update clear signature for data argument after rebase

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Remove errant virtual destructor leftover from previous impl attempt

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Remove n_embd_k/v_s from unified cache

No longer needed now that unified isn't also supporting recurrent

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140761069

Branch: HybridRecurrentCache

* refactor: Remove layer index from n_embd_k/v_s

Now that it's not used at all in the unified cache, we don't need to use
the layer index to zero it out for attention layers.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Remove n_embd_k/v_gqa from recurrent cache

This is no longer needed now that there are separate implementations

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* feat: Allow custom layer filters for hybrid recurrent

This should help support architectures like Falcon H1 where there is
overlap between layers that need attention and recurrent caches.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Remove logits_all after rebase

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Remove llama_model_is_hybrid_Recurrent public API

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2141728423

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Use llama_memory_state_ptr for child states in hybrid memory state

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern

https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-
layer components are created for attention layers and recurrent layers. The
main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/
llm_graph_input_rs android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a corresponding overload of build_rs w/
llm_graph_input_rs_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to
llm_graph_input_attn_kv_unified
- Add a build_attn override that takes
llm_graph_input_attn_kv_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input

This makes the two paradigms fully consistent. The main drawback is the
code duplication in the build_attn and build_rs implementations where the
only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* fix: Fix resize vs reserve and skip null tensors in size computation

https://github.com/ggml-org/llama.cpp/pull/13979/files#r2149469788

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
Co-Authored-By: @younesbelkada
* fix: Fix initialization of child states

Since initially writing this PR, the logic in the child state types changed
such that using the "init full" signature and keeping the ubatches on the
parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Use a common build_recurrent_state method that is cache-agnostic

This reduces the code duplication between the different build_rs impls and
also retains a similar signature to the previous build_recurrent_state
method while standardizing on the input-dispatched build_rs implementation.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* recurrent : rework graph inputs + add TODOs

ggml-ci

* refactor: Make status and child states const in hybrid and iswa

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor: Rename llama_kv_cache_[recurrent|hybrid_recurrent] to remove kv cache

This removes the notion of "kv" from the interface names for these memory
types. There are still many references to kv in the implementation of the
recurrent memory which will need further adjustment.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor!: Rename all k/v related values for recurrent/hybrid to r/s

Anywhere that "kv_<state|cell|size|etc>" is used, I've used the more
generic "mem_" prefix. The specifics of "k" (key) translate to "r"
(recurrent state) and "v" (value) translate to "s" (state-space embedding
states).

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refacor: _recurrent -> _recr for brevity

It just _happens_ to have the same number of letters as _attn!

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* style: Fix spacing for ref

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* refactor: recurrent_layer() -> is_recurrent()

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <redacted>
* style: Fix spacing for size_s_bytes declaration

Co-authored-by: Georgi Gerganov <redacted>
---------

Signed-off-by: Gabe Goodhart <redacted>
Co-authored-by: Georgi Gerganov <redacted>
17 files changed:
src/CMakeLists.txt
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-kv-cache-recurrent.cpp [deleted file]
src/llama-kv-cache-recurrent.h [deleted file]
src/llama-kv-cache-unified-iswa.cpp
src/llama-kv-cache-unified-iswa.h
src/llama-kv-cache-unified.cpp
src/llama-memory-hybrid.cpp [new file with mode: 0644]
src/llama-memory-hybrid.h [new file with mode: 0644]
src/llama-memory-recurrent.cpp [new file with mode: 0644]
src/llama-memory-recurrent.h [new file with mode: 0644]
src/llama-model.cpp

index 70be604e4b0d336a30b24d9cf0601cde62b91a91..8f9cd652447abe72c07cf1699dfc9aaad38fd9ec 100644 (file)
@@ -22,8 +22,9 @@ add_library(llama
             llama-io.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
-            llama-kv-cache-recurrent.cpp
             llama-memory.cpp
+            llama-memory-hybrid.cpp
+            llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-saver.cpp
index de8d289cf967e989cbd3fcca85639b1b11c4ba32..0bc60565df12ca77249d6449698f53836f53cfc0 100644 (file)
@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,               "%s.attention.key_length_mla"               },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,             "%s.attention.value_length_mla"             },
+    { LLM_KV_ATTENTION_LAYER_INDICES,                "%s.attention.layer_indices"                },
 
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
@@ -1816,3 +1817,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    //  the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
index 3e8a61da3c13e38fc1711e003d2f2b6ba3f59393..51b242c66b824a93498c420156d751144317b019 100644 (file)
@@ -151,6 +151,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -439,3 +440,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);
 
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
index 337fb5cb0df3634d00284c64429dce9141c641a9..65d98cbbb3987036a27180d61f8f7083114d1455 100644 (file)
@@ -6,7 +6,8 @@
 
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -238,18 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_kv = kv_state->get_n_kv();
+    const int64_t n_rs = mem_state->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -403,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recr()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -961,23 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1047,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_cgraph * gf,
          ggml_tensor * q,
@@ -1291,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
-
-    {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1430,20 +1429,99 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-         ggml_cgraph * gf,
-         ggml_tensor * s,
-         ggml_tensor * state_copy,
-             int32_t   state_size,
-             int32_t   n_seqs,
-                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    const auto n_kv    = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    {
+        const auto n_kv = kv_state->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+            int32_t   state_size,
+            int32_t   n_seqs,
+           uint32_t   n_kv,
+           uint32_t   kv_head,
+           uint32_t   kv_size,
+            int32_t   rs_zero,
+               bool   avoid_copies) const {
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1474,22 +1552,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }
 
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_rs = kv_state->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+            int32_t   state_size,
+            int32_t   n_seqs,
+               bool   avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+            int32_t   state_size,
+            int32_t   n_seqs,
+               bool   avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-         ggml_cgraph * gf,
-         ggml_tensor * state_copy,
-  const llama_ubatch & ubatch,
+    llm_graph_input_rs * inp,
+           ggml_cgraph * gf,
+    const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs  = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
+    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
-            hparams.n_embd_k_s(), n_seqs);
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
+            hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1500,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1512,7 +1627,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
+        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }
 
index 87813119b1a3cf5656bec8b837df18628e7902a5..58845e284abed87f06e4cb5f15035a42a521f316 100644 (file)
@@ -21,7 +21,8 @@ struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
-class llama_kv_cache_recurrent_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -188,16 +189,16 @@ public:
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_s_copy : public llm_graph_input_i {
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_copy() = default;
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -300,6 +301,33 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_memory_hybrid_state * mem_state) :
+        hparams(hparams),
+        cparams(cparams),
+        mem_state(mem_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_memory_hybrid_state * mem_state;
+};
+
 //
 // llm_graph_result
 //
@@ -508,13 +536,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -589,22 +618,62 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
+    ggml_tensor * build_attn(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale,
+                    int   il) const;
     //
     // recurrent
     //
 
-    ggml_tensor * build_recurrent_state(
-             ggml_cgraph * gf,
-             ggml_tensor * s,
-             ggml_tensor * state_copy,
-                 int32_t   state_size,
-                 int32_t   n_seqs,
-                    bool   avoid_copies = false) const;
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //         `llama_memory_recurrent`
+    ggml_tensor * build_rs(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+                int32_t   state_size,
+                int32_t   n_seqs,
+               uint32_t   n_kv,
+               uint32_t   kv_head,
+               uint32_t   kv_size,
+                int32_t   rs_zero,
+                   bool   avoid_copies = false) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+                int32_t   state_size,
+                int32_t   n_seqs,
+                   bool   avoid_copies = false) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+                int32_t   state_size,
+                int32_t   n_seqs,
+                   bool   avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-             ggml_cgraph * gf,
-             ggml_tensor * state_copy,
-      const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+               ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
                      int   il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
index 1499eb08a5dd9246f182dc3545d4e23aecc5ca29..b40566ced99eed9c9554182bacf5d7f738125e3c 100644 (file)
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +86,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::is_recurrent(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
index b2bcb8b01a18b8e07476cb332e6fd356687db0f8..82bb5b60849460a37bfcf74765daed7e585345ca 100644 (file)
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv      = 0.0f;
@@ -181,10 +184,13 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_r() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_s() const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool is_recurrent(uint32_t il) const;
 
     bool is_swa(uint32_t il) const;
 };
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
deleted file mode 100644 (file)
index 8f6f120..0000000
+++ /dev/null
@@ -1,1115 +0,0 @@
-#include "llama-kv-cache-recurrent.h"
-
-#include "llama-impl.h"
-#include "llama-io.h"
-#include "llama-batch.h"
-#include "llama-model.h"
-
-#include <algorithm>
-#include <cassert>
-#include <limits>
-#include <map>
-#include <stdexcept>
-
-//
-// llama_kv_cache_recurrent
-//
-
-llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
-    const int32_t n_layer = hparams.n_layer;
-
-    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
-
-    head = 0;
-    size = kv_size;
-    used = 0;
-
-    cells.clear();
-    cells.resize(kv_size);
-
-    // create a context for each buffer type
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                return nullptr;
-            }
-
-            ctx_map[buft] = ctx;
-            ctxs.emplace_back(ctx);
-
-            return ctx;
-        }
-
-        return it->second;
-    };
-
-    k_l.reserve(n_layer);
-    v_l.reserve(n_layer);
-
-    for (int i = 0; i < n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
-        const char * dev_name = "CPU";
-
-        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
-
-        if (offload) {
-            auto * dev = model.dev_layer(i);
-            buft = ggml_backend_dev_buffer_type(dev);
-
-            dev_name = ggml_backend_dev_name(dev);
-        }
-
-        LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
-
-        ggml_context * ctx = ctx_for_buft(buft);
-        if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
-        }
-
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
-    }
-
-    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        auto * buft = it.first;
-        auto * ctx  = it.second;
-
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        bufs.emplace_back(buf);
-    }
-
-    {
-        const size_t memory_size_k = size_k_bytes();
-        const size_t memory_size_v = size_v_bytes();
-
-        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-    }
-}
-
-void llama_kv_cache_recurrent::clear(bool data) {
-    for (int32_t i = 0; i < (int32_t) size; ++i) {
-        cells[i].pos = -1;
-        cells[i].seq_id.clear();
-        cells[i].src = -1;
-        cells[i].tail = -1;
-    }
-
-    head = 0;
-    used = 0;
-
-    if (data) {
-        for (auto & buf : bufs) {
-            ggml_backend_buffer_clear(buf.get(), 0);
-        }
-    }
-}
-
-bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // models like Mamba or RWKV can't have a state partially erased
-    if (seq_id >= (int64_t) size) {
-        // could be fatal
-        return false;
-    }
-    if (0 <= seq_id) {
-        int32_t & tail_id = cells[seq_id].tail;
-        if (tail_id >= 0) {
-            const kv_cell & cell = cells[tail_id];
-            // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                return false;
-            }
-            // invalidate tails which will be cleared
-            if (p0 <= cell.pos && cell.pos < p1) {
-                tail_id = -1;
-            }
-        }
-    } else {
-        // seq_id is negative, then the range should include everything or nothing
-        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-            return false;
-        }
-    }
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].pos >= p0 && cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cells[i].seq_id.clear();
-            } else if (cells[i].has_seq_id(seq_id)) {
-                cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-            if (cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cells[i].pos >= 0) {
-                    used--;
-                }
-                cells[i].pos = -1;
-                cells[i].src = -1;
-                if (new_head == size) {
-                    new_head = i;
-                }
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
-        head = new_head;
-    }
-
-    return true;
-}
-
-void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
-        return;
-    }
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
-        kv_cell & tail_src = cells[seq_id_src];
-        kv_cell & tail_dst = cells[seq_id_dst];
-        if (tail_dst.tail >= 0) {
-            // clear destination seq_id if it wasn't empty
-            kv_cell & cell_dst = cells[tail_dst.tail];
-
-            cell_dst.seq_id.erase(seq_id_dst);
-            tail_dst.tail = -1;
-            if (cell_dst.seq_id.empty()) {
-                cell_dst.pos = -1;
-                cell_dst.src = -1;
-                used -= 1;
-            }
-        }
-        if (tail_src.tail >= 0) {
-            kv_cell & cell_src = cells[tail_src.tail];
-
-            cell_src.seq_id.insert(seq_id_dst);
-            tail_dst.tail = tail_src.tail;
-        }
-    }
-}
-
-void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
-    uint32_t new_head = size;
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if ((llama_seq_id) i != seq_id) {
-            cells[i].tail = -1;
-        }
-
-        if (!cells[i].has_seq_id(seq_id)) {
-            if (cells[i].pos >= 0) {
-                used--;
-            }
-
-            cells[i].pos = -1;
-            cells[i].src = -1;
-            cells[i].seq_id.clear();
-
-            if (new_head == size){
-                new_head = i;
-            }
-        } else {
-            cells[i].seq_id.clear();
-            cells[i].seq_id.insert(seq_id);
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
-        head = new_head;
-    }
-}
-
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
-    if (shift == 0) {
-        return;
-    }
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // If there is no range then return early to avoid looping over the
-    if (p0 == p1) {
-        return;
-    }
-
-    // for Mamba-like or RWKV models, only the pos needs to be shifted
-    if (0 <= seq_id && seq_id < (int64_t) size) {
-        const int32_t tail_id = cells[seq_id].tail;
-        if (tail_id >= 0) {
-            kv_cell & cell = cells[tail_id];
-            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos += shift;
-            }
-        }
-    }
-}
-
-void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    if (d == 1) {
-        return;
-    }
-
-    if (p0 < 0) {
-        p0 = 0;
-    }
-
-    if (p1 < 0) {
-        p1 = std::numeric_limits<llama_pos>::max();
-    }
-
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) {
-        return;
-    }
-
-    // for Mamba-like or RWKV models, only the pos needs to be changed
-    if (0 <= seq_id && seq_id < (int64_t) size) {
-        const int32_t tail_id = cells[seq_id].tail;
-        if (tail_id >= 0) {
-            kv_cell & cell = cells[tail_id];
-            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                cell.pos /= d;
-            }
-        }
-    }
-}
-
-llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
-    llama_pos result = std::numeric_limits<llama_pos>::max();
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::min(result, cells[i].pos);
-        }
-    }
-
-    if (result == std::numeric_limits<llama_pos>::max()) {
-        result = -1;
-    }
-
-    return result;
-}
-
-llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
-    llama_pos result = -1;
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cells[i].pos);
-        }
-    }
-
-    return result;
-}
-
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
-
-    std::vector<llama_ubatch> ubatches;
-
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = sbatch.split_seq(n_ubatch);
-        } else {
-            ubatch = sbatch.split_equal(n_ubatch);
-        }
-
-        ubatches.push_back(ubatch);
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
-}
-
-llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
-}
-
-llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
-    GGML_UNUSED(lctx);
-    GGML_UNUSED(optimize);
-
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
-}
-
-bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
-    // simply remember the full state because it is very small for this type of cache
-    // TODO: optimize
-    auto org_cells = cells;
-    auto org_used = used;
-    auto org_head = head;
-
-    bool success = true;
-
-    for (const auto & ubatch : ubatches) {
-        if (!find_slot(ubatch)) {
-            success = false;
-            break;
-        }
-    }
-
-    // restore the original state
-    cells = std::move(org_cells);
-    used = org_used;
-    head = org_head;
-
-    return success;
-}
-
-bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_seqs = ubatch.n_seqs;
-
-    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
-
-    // if we have enough unused cells before the current head ->
-    //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_seqs) {
-        head = 0;
-    }
-
-    // For recurrent state architectures (like Mamba or RWKV),
-    // each cache cell can store the state for a whole sequence.
-    // A slot should be always be contiguous.
-
-    // can only process batches with an equal number of new tokens in each sequence
-    GGML_ASSERT(ubatch.equal_seqs);
-
-    int32_t min = size - 1;
-    int32_t max = 0;
-
-    // everything should fit if all seq_ids are smaller than the max
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const uint32_t n_seq_id = ubatch.n_seq_id[s];
-        for (uint32_t j = 0; j < n_seq_id; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
-
-            if (seq_id < 0 || (uint32_t) seq_id >= size) {
-                // too big seq_id
-                // TODO: would it be possible to resize the cache instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
-                return false;
-            }
-            if (j > 0) {
-                kv_cell & seq = cells[seq_id];
-                if (seq.tail >= 0) {
-                    kv_cell & cell = cells[seq.tail];
-                    // clear cells from seq_ids that become shared
-                    // (should not normally happen, but let's handle it anyway)
-                    cell.seq_id.erase(seq_id);
-                    seq.tail = -1;
-                    if (cell.seq_id.empty()) {
-                        cell.pos = -1;
-                        cell.src = -1;
-                        used -= 1;
-                    }
-                }
-            }
-        }
-    }
-
-#ifndef NDEBUG
-    {
-        std::vector<int32_t> tails_verif;
-        tails_verif.assign(size, -1);
-        for (uint32_t i = 0; i < size; ++i) {
-            kv_cell & cell = cells[i];
-            for (llama_seq_id seq_id : cell.seq_id) {
-                if (tails_verif[seq_id] != -1) {
-                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
-                }
-                tails_verif[seq_id] = i;
-            }
-        }
-        for (uint32_t i = 0; i < size; ++i) {
-            if (tails_verif[i] != cells[i].tail) {
-                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
-            }
-        }
-    }
-#endif
-
-    // find next empty cell
-    uint32_t next_empty_cell = head;
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if (next_empty_cell >= size) { next_empty_cell -= size; }
-        kv_cell & cell = cells[next_empty_cell];
-        if (cell.is_empty()) { break; }
-        next_empty_cell += 1;
-    }
-
-    // find usable cell range
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const llama_seq_id seq_id = ubatch.seq_id[s][0];
-        kv_cell & seq_meta = cells[seq_id];
-        bool has_cell = false;
-        if (seq_meta.tail >= 0) {
-            kv_cell & cell = cells[seq_meta.tail];
-            GGML_ASSERT(cell.has_seq_id(seq_id));
-            // does this seq_id "own" the cell?
-            if (cell.seq_id.size() == 1) { has_cell = true; }
-        }
-        if (!has_cell) {
-            kv_cell & empty_cell = cells[next_empty_cell];
-            GGML_ASSERT(empty_cell.is_empty());
-            // copy old tail into the empty cell
-            if (seq_meta.tail >= 0) {
-                kv_cell & orig_cell = cells[seq_meta.tail];
-                empty_cell.pos = orig_cell.pos;
-                empty_cell.src = orig_cell.src;
-                orig_cell.seq_id.erase(seq_id);
-                empty_cell.seq_id.insert(seq_id); // will be overwritten
-                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
-            }
-            seq_meta.tail = next_empty_cell;
-            // find next empty cell
-            if (s + 1 < n_seqs) {
-                for (uint32_t i = 0; i < size; ++i) {
-                    next_empty_cell += 1;
-                    if (next_empty_cell >= size) { next_empty_cell -= size; }
-                    kv_cell & cell = cells[next_empty_cell];
-                    if (cell.is_empty()) { break; }
-                }
-            }
-        }
-        if (min > seq_meta.tail) { min = seq_meta.tail; }
-        if (max < seq_meta.tail) { max = seq_meta.tail; }
-    }
-
-    // gather and re-order
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const int32_t dst_id = s + min;
-        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
-        if (dst_id != src_id) {
-            kv_cell & dst_cell = cells[dst_id];
-            kv_cell & src_cell = cells[src_id];
-
-            std::swap(dst_cell.pos, src_cell.pos);
-            std::swap(dst_cell.src, src_cell.src);
-            std::swap(dst_cell.seq_id, src_cell.seq_id);
-
-            // swap tails
-            for (uint32_t i = 0; i < size; ++i) {
-                int32_t & tail = cells[i].tail;
-                if (tail == src_id) {
-                    tail = dst_id;
-                } else if (tail == dst_id) {
-                    tail = src_id;
-                }
-            }
-        }
-    }
-
-    // update the pos of the used seqs
-    for (uint32_t s = 0; s < n_seqs; ++s) {
-        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        const int32_t cell_id = s + min;
-        kv_cell & cell = cells[cell_id];
-
-        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
-            // What should happen when the pos backtracks or skips a value?
-            // Clearing the state mid-batch would require special-casing which isn't done.
-            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
-        }
-        cell.pos = last_pos;
-        cell.seq_id.clear();
-        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
-            const llama_seq_id seq_id = ubatch.seq_id[s][j];
-            cell.seq_id.insert(seq_id);
-            cells[seq_id].tail = cell_id;
-        }
-    }
-
-    // Find first cell without src refs, to use as the zero-ed state
-    {
-        // TODO: bake-in src refcounts in the cell metadata
-        std::vector<int32_t> refcounts(size, 0);
-        for (size_t i = 0; i < size; ++i) {
-            const int32_t src = cells[i].src;
-            if (src >= 0) {
-                refcounts[src] += 1;
-            }
-        }
-
-        rs_z = -1;
-        for (int i = min; i <= max; ++i) {
-            if (refcounts[i] == 0) {
-                rs_z = i;
-                break;
-            }
-        }
-
-        for (int i = min; i <= max; ++i) {
-            if (cells[i].src < 0) {
-                GGML_ASSERT(rs_z >= 0);
-                cells[i].src0 = rs_z;
-            } else {
-                // Stage the source ids for all used cells to allow correct seq_* behavior
-                // and still make these values available when setting the inputs
-                cells[i].src0 = cells[i].src;
-            }
-            cells[i].src = i; // avoid moving or clearing twice
-        }
-    }
-
-    // allow getting the range of used cells, from head to head + n
-    head = min;
-    n    = max - min + 1;
-    used = std::count_if(cells.begin(), cells.end(),
-        [](const kv_cell & cell){ return !cell.is_empty(); });
-
-    // sanity check
-    return n >= n_seqs;
-}
-
-bool llama_kv_cache_recurrent::get_can_shift() const {
-    // shifting the pos is trivial for recurrent models
-    return true;
-}
-
-size_t llama_kv_cache_recurrent::total_size() const {
-    size_t size = 0;
-    for (const auto & buf : bufs) {
-        size += ggml_backend_buffer_get_size(buf.get());
-    }
-
-    return size;
-}
-
-size_t llama_kv_cache_recurrent::size_k_bytes() const {
-    size_t size_k_bytes = 0;
-
-    for (const auto & k : k_l) {
-        size_k_bytes += ggml_nbytes(k);
-    }
-
-    return size_k_bytes;
-}
-
-size_t llama_kv_cache_recurrent::size_v_bytes() const {
-    size_t size_v_bytes = 0;
-
-    for (const auto & v : v_l) {
-        size_v_bytes += ggml_nbytes(v);
-    }
-
-    return size_v_bytes;
-}
-
-void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
-    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-    uint32_t cell_count = 0;
-
-    // Count the number of cells with the specified seq_id
-    // Find all the ranges of cells with this seq id (or all, when -1)
-    uint32_t cell_range_begin = size;
-    for (uint32_t i = 0; i < size; ++i) {
-        const auto & cell = cells[i];
-        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
-            ++cell_count;
-            if (cell_range_begin == size) {
-                cell_range_begin = i;
-            }
-        } else {
-            if (cell_range_begin != size) {
-                cell_ranges.emplace_back(cell_range_begin, i);
-                cell_range_begin = size;
-            }
-        }
-    }
-    if (cell_range_begin != size) {
-        cell_ranges.emplace_back(cell_range_begin, size);
-    }
-
-    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-    uint32_t cell_count_check = 0;
-    for (const auto & range : cell_ranges) {
-        cell_count_check += range.second - range.first;
-    }
-    GGML_ASSERT(cell_count == cell_count_check);
-
-    io.write(&cell_count, sizeof(cell_count));
-
-    state_write_meta(io, cell_ranges, seq_id);
-    state_write_data(io, cell_ranges);
-}
-
-void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
-    uint32_t cell_count;
-    io.read_to(&cell_count, sizeof(cell_count));
-
-    bool res = true;
-
-    res = res && state_read_meta(io, cell_count, seq_id);
-    res = res && state_read_data(io, cell_count);
-
-    if (!res) {
-        if (seq_id == -1) {
-            clear(true);
-        } else {
-            seq_rm(seq_id, -1, -1);
-        }
-        throw std::runtime_error("failed to restore kv cache");
-    }
-}
-
-void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
-    for (const auto & range : cell_ranges) {
-        for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = cells[i];
-            const llama_pos pos      = cell.pos;
-            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
-
-            io.write(&pos,      sizeof(pos));
-            io.write(&n_seq_id, sizeof(n_seq_id));
-
-            if (n_seq_id) {
-                for (auto seq_id : cell.seq_id) {
-                    io.write(&seq_id, sizeof(seq_id));
-                }
-            }
-        }
-    }
-}
-
-void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
-    const uint32_t v_trans = 0;
-    const uint32_t n_layer = hparams.n_layer;
-
-    io.write(&v_trans, sizeof(v_trans));
-    io.write(&n_layer, sizeof(n_layer));
-
-    std::vector<uint8_t> tmp_buf;
-
-    // Iterate and write all the keys first, each row is a cell
-    // Get whole range at a time
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-        // Write key type
-        const int32_t k_type_i = (int32_t)k_l[il]->type;
-        io.write(&k_type_i, sizeof(k_type_i));
-
-        // Write row size of key
-        const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        io.write(&k_size_row, sizeof(k_size_row));
-
-        // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cell_ranges) {
-            const size_t range_size = range.second - range.first;
-            const size_t buf_size = range_size * k_size_row;
-            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
-        }
-    }
-
-    if (!v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Write value type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            io.write(&v_type_i, sizeof(v_type_i));
-
-            // Write row size of value
-            const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
-            io.write(&v_size_row, sizeof(v_size_row));
-
-            // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * v_size_row;
-                io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
-            }
-        }
-    } else {
-        // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = size;
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Write value type
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            io.write(&v_type_i, sizeof(v_type_i));
-
-            // Write element size
-            const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
-            io.write(&v_size_el, sizeof(v_size_el));
-
-            // Write GQA embedding size
-            io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
-
-            // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    const size_t buf_size = range_size * v_size_el;
-                    io.write_tensor(v_l[il], src_offset, buf_size);
-                }
-            }
-        }
-    }
-}
-
-bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
-    if (dest_seq_id != -1) {
-        // single sequence
-
-        seq_rm(dest_seq_id, -1, -1);
-
-        llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-
-        batch.n_tokens = cell_count;
-        batch.n_seq_tokens = cell_count;
-        batch.n_seqs = 1;
-
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            uint32_t n_seq_id;
-
-            io.read_to(&pos,      sizeof(pos));
-            io.read_to(&n_seq_id, sizeof(n_seq_id));
-
-            if (n_seq_id != 0) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                return false;
-            }
-
-            batch.pos[i] = pos;
-        }
-        batch.n_seq_id[0] = 1;
-        batch.seq_id[0] = &dest_seq_id;
-
-        if (!find_slot(batch)) {
-            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-            return false;
-        }
-
-        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head + cell_count <= size);
-        GGML_ASSERT(cells[head].pos == batch.pos[0]);
-        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
-    } else {
-        // whole KV cache restore
-
-        if (cell_count > size) {
-            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-            return false;
-        }
-
-        clear(true);
-
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            kv_cell & cell = cells[i];
-
-            llama_pos pos;
-            uint32_t  n_seq_id;
-
-            io.read_to(&pos,      sizeof(pos));
-            io.read_to(&n_seq_id, sizeof(n_seq_id));
-
-            cell.pos = pos;
-
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                llama_seq_id seq_id;
-                io.read_to(&seq_id, sizeof(seq_id));
-
-                // TODO: llama_kv_cache_recurrent should have a notion of max sequences
-                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-                if (seq_id < 0) {
-                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
-                    return false;
-                }
-
-                cell.seq_id.insert(seq_id);
-
-                int32_t & tail = cells[seq_id].tail;
-                if (tail != -1) {
-                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                    return false;
-                }
-                tail = i;
-            }
-        }
-
-        head = 0;
-        used = cell_count;
-    }
-
-    for (uint32_t i = 0; i < cell_count; ++i) {
-        uint32_t cell_id = head + i;
-        // make sure the recurrent states will keep their restored state
-        cells[cell_id].src = cell_id;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
-    uint32_t v_trans;
-    uint32_t n_layer;
-    io.read_to(&v_trans, sizeof(v_trans));
-    io.read_to(&n_layer, sizeof(n_layer));
-
-    if (n_layer != hparams.n_layer) {
-        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
-        return false;
-    }
-    if (cell_count > size) {
-        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
-        return false;
-    }
-    if (false != (bool) v_trans) {
-        LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
-        return false;
-    }
-
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-        // Read type of key
-        int32_t k_type_i_ref;
-        io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k_l[il]->type;
-        if (k_type_i != k_type_i_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return false;
-        }
-
-        // Read row size of key
-        uint64_t k_size_row_ref;
-        io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
-            return false;
-        }
-
-        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
-        }
-    }
-
-    if (!v_trans) {
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
-            }
-
-            // Read row size of value
-            uint64_t v_size_row_ref;
-            io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
-            }
-        }
-    } else {
-        // For each layer, read the values for each cell (transposed)
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-            // Read type of value
-            int32_t v_type_i_ref;
-            io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t)v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
-            }
-
-            // Read element size of value
-            uint32_t v_size_el_ref;
-            io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-            const size_t v_size_el = ggml_type_size(v_l[il]->type);
-            if (v_size_el != v_size_el_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
-                return false;
-            }
-
-            // Read GQA embedding size
-            uint32_t n_embd_v_gqa_ref;
-            io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-            if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * size) * v_size_el;
-                    ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                }
-            }
-        }
-    }
-
-    return true;
-}
-
-//
-// llama_kv_cache_recurrent_state
-//
-
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
-
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
-}
-
-llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv,
-        llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
-
-llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
-
-bool llama_kv_cache_recurrent_state::next() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    if (++i_next >= ubatches.size()) {
-        return false;
-    }
-
-    return true;
-}
-
-bool llama_kv_cache_recurrent_state::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    kv->find_slot(ubatches[i_next]);
-
-    return true;
-}
-
-std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
-llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
-    return status;
-}
-
-const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return ubatches[i_next];
-}
-
-uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
-    return is_full ? kv->size : kv->n;
-}
-
-uint32_t llama_kv_cache_recurrent_state::get_head() const {
-    return is_full ? 0 : kv->head;
-}
-
-int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
-    return is_full ? 0 : kv->rs_z;
-}
-
-uint32_t llama_kv_cache_recurrent_state::get_size() const {
-    return kv->size;
-}
-
-ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
-    return kv->k_l[il];
-}
-
-ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
-    return kv->v_l[il];
-}
-
-int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return  kv->cells[i + kv->head].src0;
-}
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
deleted file mode 100644 (file)
index f9b01a6..0000000
+++ /dev/null
@@ -1,184 +0,0 @@
-#pragma once
-
-#include "llama-batch.h"
-#include "llama-graph.h"
-#include "llama-memory.h"
-
-#include <set>
-#include <vector>
-
-//
-// llama_kv_cache_recurrent
-//
-
-// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
-//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
-class llama_kv_cache_recurrent : public llama_memory_i {
-public:
-    llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
-
-    ~llama_kv_cache_recurrent() = default;
-
-    //
-    // llama_memory_i
-    //
-
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_all) override;
-
-    llama_memory_state_ptr init_full() override;
-
-    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
-
-    void clear(bool data) override;
-
-    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
-    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id)                                                          override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
-    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
-
-    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
-    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
-
-    bool prepare(const std::vector<llama_ubatch> & ubatches);
-
-    // find a contiguous slot of kv cells and emplace the ubatch there
-    bool find_slot(const llama_ubatch & ubatch);
-
-    bool get_can_shift() const override;
-
-    // state write/load
-
-    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
-
-    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
-    uint32_t size = 0; // total number of cells, shared across all sequences
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    // first zero-ed state
-    int32_t rs_z = -1;
-
-    // TODO: optimize for recurrent state needs
-    struct kv_cell {
-        llama_pos pos  = -1;
-        int32_t   src  = -1; // used to know where states should be copied from
-        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
-        int32_t   tail = -1;
-
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
-    std::vector<kv_cell> cells;
-
-    std::vector<ggml_tensor *> k_l; // per layer
-    std::vector<ggml_tensor *> v_l;
-
-private:
-    //const llama_model & model;
-    const llama_hparams & hparams;
-
-    const uint32_t n_seq_max = 1;
-
-    std::vector<ggml_context_ptr>        ctxs;
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    size_t total_size() const;
-
-    size_t size_k_bytes() const;
-    size_t size_v_bytes() const;
-
-    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
-
-    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
-};
-
-class llama_kv_cache_recurrent_state : public llama_memory_state_i {
-public:
-    // used for errors
-    llama_kv_cache_recurrent_state(llama_memory_status status);
-
-    // used to create a full-cache state
-    llama_kv_cache_recurrent_state(
-            llama_memory_status status,
-            llama_kv_cache_recurrent * kv);
-
-    // used to create a state from a batch
-    llama_kv_cache_recurrent_state(
-            llama_memory_status status,
-            llama_kv_cache_recurrent * kv,
-            llama_sbatch sbatch,
-            std::vector<llama_ubatch> ubatches);
-
-    virtual ~llama_kv_cache_recurrent_state();
-
-    //
-    // llama_memory_state_i
-    //
-
-    bool next()  override;
-    bool apply() override;
-
-    std::vector<int64_t> & out_ids() override;
-
-    llama_memory_status  get_status() const override;
-    const llama_ubatch & get_ubatch() const override;
-
-    //
-    // llama_kv_cache_recurrent_state specific API
-    //
-
-    uint32_t get_n_kv() const;
-    uint32_t get_head() const;
-    int32_t  get_rs_z() const;
-    uint32_t get_size() const;
-
-    ggml_tensor * get_k_l(int32_t il) const;
-    ggml_tensor * get_v_l(int32_t il) const;
-
-    int32_t s_copy(int i) const;
-
-private:
-    const llama_memory_status status;
-
-    llama_kv_cache_recurrent * kv;
-
-    llama_sbatch sbatch;
-
-    size_t i_next = 0;
-
-    std::vector<llama_ubatch> ubatches;
-
-    //
-    // data needed for building the compute graph for the current ubatch:
-    // TODO: extract all the state like `head` and `n` here
-    //
-
-    const bool is_full = false;
-};
index a4a4c2b1b859de2c2b4be3904f8ecdf0c2e27d9c..a869b1de8c2a321bfa3ed0945a726de83d48a244 100644 (file)
@@ -197,21 +197,19 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_full();
-    state_swa  = kv->get_swa ()->init_full();
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        llama_kv_cache_unified_iswa * kv) :
+    state_base(kv->get_base()->init_full()),
+    state_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_update(lctx, optimize);
-    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        bool optimize) :
+    state_base(kv->get_base()->init_update(lctx, optimize)),
+    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
@@ -219,15 +217,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches)
-        : status(LLAMA_MEMORY_STATUS_SUCCESS),
-        sbatch(std::move(sbatch)),
-        ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches) :
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches));
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
+    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa),  this->ubatches)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
index 6e941e1a41b88b31f0838c4834ae6eda830a3455..813eaf39b25b0202b8a7c9b463f4d882112c1ee8 100644 (file)
@@ -117,8 +117,6 @@ public:
     const llama_kv_cache_unified_state * get_swa()  const;
 
 private:
-    llama_memory_status status;
-
     //llama_kv_cache_unified_iswa * kv;
 
     llama_sbatch sbatch;
@@ -128,6 +126,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    llama_memory_state_ptr state_base;
-    llama_memory_state_ptr state_swa;
+    const llama_memory_state_ptr state_base;
+    const llama_memory_state_ptr state_swa;
+
+    const llama_memory_status status;
 };
index 3b37679859d392481c4e56bb1004acbb4004ee76..d4412288925c3d2d4bb67996531e3f486ccc7b98 100644 (file)
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         const char * dev_name = "CPU";
 
@@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
 
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             // Read type of value
             int32_t v_type_i_ref;
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
new file mode 100644 (file)
index 0000000..d4b260d
--- /dev/null
@@ -0,0 +1,247 @@
+#include "llama-memory-hybrid.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid
+//
+
+llama_memory_hybrid::llama_memory_hybrid(
+    const llama_model & model,
+                         /* attn */
+            ggml_type    type_k,
+            ggml_type    type_v,
+                 bool    v_trans,
+             uint32_t    kv_size,
+             uint32_t    n_pad,
+             uint32_t    n_swa,
+       llama_swa_type    swa_type,
+                         /* recurrent */
+            ggml_type    type_r,
+            ggml_type    type_s,
+             uint32_t    rs_size,
+                         /* common */
+             uint32_t    n_seq_max,
+                 bool    offload,
+                         /* layer filters */
+      layer_filter_cb && filter_attn,
+      layer_filter_cb && filter_recr) :
+    hparams(model.hparams),
+    mem_attn(new llama_kv_cache_unified(
+        model,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !model.hparams.is_recurrent(il); }
+            : filter_attn,
+        type_k,
+        type_v,
+        v_trans,
+        offload,
+        kv_size,
+        n_seq_max,
+        n_pad,
+        n_swa,
+        swa_type
+    )),
+    mem_recr(new llama_memory_recurrent(
+        model,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return model.hparams.is_recurrent(il); }
+            : filter_recr,
+        type_r,
+        type_s,
+        offload,
+        rs_size,
+        n_seq_max
+    )) {}
+
+llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!mem_recr->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = mem_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_memory_hybrid_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_state>(this);
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid::clear(bool data) {
+    mem_attn->clear(data);
+    mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    mem_attn->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
+}
+
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+    return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
+    state_attn(mem->get_mem_attn()->init_full()),
+    state_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+              llama_context * lctx,
+                       bool   optimize) :
+    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+              llama_memory_hybrid * mem,
+                     llama_sbatch   sbatch,
+            std::vector<uint32_t>   heads_attn,
+        std::vector<llama_ubatch>   ubatches) :
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {},                        this->ubatches)),
+    status(LLAMA_MEMORY_STATUS_SUCCESS) {
+}
+
+bool llama_memory_hybrid_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    state_attn->next();
+    state_recr->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_hybrid_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    bool res = true;
+
+    res = res & state_attn->apply();
+    res = res & state_recr->apply();
+
+    return res;
+}
+
+std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_memory_hybrid_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
+    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+}
+
+const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
+    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+}
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
new file mode 100644 (file)
index 0000000..b5700c5
--- /dev/null
@@ -0,0 +1,143 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+//   support models where each layer may be either attention-based or recurrent
+
+class llama_memory_hybrid : public llama_memory_i {
+public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_memory_hybrid(
+        const llama_model & model,
+                            /* attn */
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    v_trans,
+                 uint32_t    kv_size,
+                 uint32_t    n_pad,
+                 uint32_t    n_swa,
+           llama_swa_type    swa_type,
+                             /* recurrent */
+                ggml_type    type_r,
+                ggml_type    type_s,
+                 uint32_t    rs_size,
+                             /* common */
+                 uint32_t    n_seq_max,
+                     bool    offload,
+                             /* layer filters */
+          layer_filter_cb && filter_attn = nullptr,
+          layer_filter_cb && filter_recr = nullptr);
+
+    ~llama_memory_hybrid() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+    //
+    // llama_memory_hybrid specific API
+    //
+
+    llama_kv_cache_unified * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
+class llama_memory_hybrid_state : public llama_memory_state_i {
+public:
+    // init failure
+    explicit llama_memory_hybrid_state(llama_memory_status status);
+
+    // init full
+    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+
+    // init update
+    explicit llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+              llama_context * lctx,
+                       bool   optimize);
+
+    // init success
+    llama_memory_hybrid_state(
+              llama_memory_hybrid * mem,
+                     llama_sbatch   sbatch,
+            std::vector<uint32_t>   heads_attn,
+        std::vector<llama_ubatch>   ubatches);
+
+    ~llama_memory_hybrid_state() = default;
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_hybrid_state
+    //
+
+    const llama_kv_cache_unified_state * get_state_attn() const;
+    const llama_memory_recurrent_state * get_state_recr() const;
+
+private:
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_state_ptr state_attn;
+    const llama_memory_state_ptr state_recr;
+
+    const llama_memory_status status;
+};
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
new file mode 100644 (file)
index 0000000..c4f9a6f
--- /dev/null
@@ -0,0 +1,1116 @@
+#include "llama-memory-recurrent.h"
+
+#include "llama-impl.h"
+#include "llama-io.h"
+#include "llama-batch.h"
+#include "llama-model.h"
+
+#include <algorithm>
+#include <cassert>
+#include <limits>
+#include <map>
+#include <stdexcept>
+
+//
+// llama_memory_recurrent
+//
+
+llama_memory_recurrent::llama_memory_recurrent(
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_r,
+                ggml_type    type_s,
+                     bool    offload,
+                 uint32_t    mem_size,
+                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+    const int32_t n_layer = hparams.n_layer;
+
+    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
+            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
+
+    head = 0;
+    size = mem_size;
+    used = 0;
+
+    cells.clear();
+    cells.resize(mem_size);
+
+    // create a context for each buffer type
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ size_t(2u*n_layer*ggml_tensor_overhead()),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                return nullptr;
+            }
+
+            ctx_map[buft] = ctx;
+            ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    r_l.resize(n_layer);
+    s_l.resize(n_layer);
+
+    for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
+        const char * dev_name = "CPU";
+
+        ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
+
+        if (offload) {
+            auto * dev = model.dev_layer(i);
+            buft = ggml_backend_dev_buffer_type(dev);
+
+            dev_name = ggml_backend_dev_name(dev);
+        }
+
+        LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name);
+
+        ggml_context * ctx = ctx_for_buft(buft);
+        if (!ctx) {
+            throw std::runtime_error("failed to create ggml context for kv cache");
+        }
+
+        ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
+        ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
+        ggml_format_name(r, "cache_r_l%d", i);
+        ggml_format_name(s, "cache_s_l%d", i);
+        r_l[i] = r;
+        s_l[i] = s;
+    }
+
+    // allocate tensors and initialize the buffers to avoid NaNs in the padding
+    for (auto it : ctx_map) {
+        auto * buft = it.first;
+        auto * ctx  = it.second;
+
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            throw std::runtime_error("failed to allocate buffer for kv cache");
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        bufs.emplace_back(buf);
+    }
+
+    {
+        const size_t memory_size_r = size_r_bytes();
+        const size_t memory_size_s = size_s_bytes();
+
+        LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+                ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
+                ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
+    }
+}
+
+void llama_memory_recurrent::clear(bool data) {
+    for (int32_t i = 0; i < (int32_t) size; ++i) {
+        cells[i].pos = -1;
+        cells[i].seq_id.clear();
+        cells[i].src = -1;
+        cells[i].tail = -1;
+    }
+
+    head = 0;
+    used = 0;
+
+    if (data) {
+        for (auto & buf : bufs) {
+            ggml_backend_buffer_clear(buf.get(), 0);
+        }
+    }
+}
+
+bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    uint32_t new_head = size;
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // models like Mamba or RWKV can't have a state partially erased
+    if (seq_id >= (int64_t) size) {
+        // could be fatal
+        return false;
+    }
+    if (0 <= seq_id) {
+        int32_t & tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const auto & cell = cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+            // invalidate tails which will be cleared
+            if (p0 <= cell.pos && cell.pos < p1) {
+                tail_id = -1;
+            }
+        }
+    } else {
+        // seq_id is negative, then the range should include everything or nothing
+        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            return false;
+        }
+    }
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].pos >= p0 && cells[i].pos < p1) {
+            if (seq_id < 0) {
+                cells[i].seq_id.clear();
+            } else if (cells[i].has_seq_id(seq_id)) {
+                cells[i].seq_id.erase(seq_id);
+            } else {
+                continue;
+            }
+            if (cells[i].is_empty()) {
+                // keep count of the number of used cells
+                if (cells[i].pos >= 0) {
+                    used--;
+                }
+                cells[i].pos = -1;
+                cells[i].src = -1;
+                if (new_head == size) {
+                    new_head = i;
+                }
+            }
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+
+    return true;
+}
+
+void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
+        auto & tail_src = cells[seq_id_src];
+        auto & tail_dst = cells[seq_id_dst];
+        if (tail_dst.tail >= 0) {
+            // clear destination seq_id if it wasn't empty
+            auto & cell_dst = cells[tail_dst.tail];
+
+            cell_dst.seq_id.erase(seq_id_dst);
+            tail_dst.tail = -1;
+            if (cell_dst.seq_id.empty()) {
+                cell_dst.pos = -1;
+                cell_dst.src = -1;
+                used -= 1;
+            }
+        }
+        if (tail_src.tail >= 0) {
+            auto & cell_src = cells[tail_src.tail];
+
+            cell_src.seq_id.insert(seq_id_dst);
+            tail_dst.tail = tail_src.tail;
+        }
+    }
+}
+
+void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
+    uint32_t new_head = size;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if ((llama_seq_id) i != seq_id) {
+            cells[i].tail = -1;
+        }
+
+        if (!cells[i].has_seq_id(seq_id)) {
+            if (cells[i].pos >= 0) {
+                used--;
+            }
+
+            cells[i].pos = -1;
+            cells[i].src = -1;
+            cells[i].seq_id.clear();
+
+            if (new_head == size){
+                new_head = i;
+            }
+        } else {
+            cells[i].seq_id.clear();
+            cells[i].seq_id.insert(seq_id);
+        }
+    }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != size && new_head < head) {
+        head = new_head;
+    }
+}
+
+void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the
+    if (p0 == p1) {
+        return;
+    }
+
+    // for Mamba-like or RWKV models, only the pos needs to be shifted
+    if (0 <= seq_id && seq_id < (int64_t) size) {
+        const int32_t tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            auto & cell = cells[tail_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += shift;
+            }
+        }
+    }
+}
+
+void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+
+    // If there is no range then return early to avoid looping over the cache.
+    if (p0 == p1) {
+        return;
+    }
+
+    // for Mamba-like or RWKV models, only the pos needs to be changed
+    if (0 <= seq_id && seq_id < (int64_t) size) {
+        const int32_t tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            auto & cell = cells[tail_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+    }
+}
+
+llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    llama_pos result = std::numeric_limits<llama_pos>::max();
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::min(result, cells[i].pos);
+        }
+    }
+
+    if (result == std::numeric_limits<llama_pos>::max()) {
+        result = -1;
+    }
+
+    return result;
+}
+
+llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    llama_pos result = -1;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (cells[i].has_seq_id(seq_id)) {
+            result = std::max(result, cells[i].pos);
+        }
+    }
+
+    return result;
+}
+
+llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+
+    std::vector<llama_ubatch> ubatches;
+
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_all) {
+            // if all tokens are output, split by sequence
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    if (!prepare(ubatches)) {
+        return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_state>(this);
+}
+
+llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
+    GGML_UNUSED(lctx);
+    GGML_UNUSED(optimize);
+
+    return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
+}
+
+bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
+    // simply remember the full state because it is very small for this type of cache
+    // TODO: optimize
+    auto org_cells = cells;
+    auto org_used = used;
+    auto org_head = head;
+
+    bool success = true;
+
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
+
+    // restore the original state
+    cells = std::move(org_cells);
+    used = org_used;
+    head = org_head;
+
+    return success;
+}
+
+bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
+    const uint32_t n_seqs = ubatch.n_seqs;
+
+    const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (head > used + 2*n_seqs) {
+        head = 0;
+    }
+
+    // For recurrent state architectures (like Mamba or RWKV),
+    // each cache cell can store the state for a whole sequence.
+    // A slot should be always be contiguous.
+
+    // can only process batches with an equal number of new tokens in each sequence
+    GGML_ASSERT(ubatch.equal_seqs);
+
+    int32_t min = size - 1;
+    int32_t max = 0;
+
+    // everything should fit if all seq_ids are smaller than the max
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const uint32_t n_seq_id = ubatch.n_seq_id[s];
+        for (uint32_t j = 0; j < n_seq_id; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+
+            if (seq_id < 0 || (uint32_t) seq_id >= size) {
+                // too big seq_id
+                // TODO: would it be possible to resize the cache instead?
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
+                return false;
+            }
+            if (j > 0) {
+                auto & seq = cells[seq_id];
+                if (seq.tail >= 0) {
+                    auto & cell = cells[seq.tail];
+                    // clear cells from seq_ids that become shared
+                    // (should not normally happen, but let's handle it anyway)
+                    cell.seq_id.erase(seq_id);
+                    seq.tail = -1;
+                    if (cell.seq_id.empty()) {
+                        cell.pos = -1;
+                        cell.src = -1;
+                        used -= 1;
+                    }
+                }
+            }
+        }
+    }
+
+#ifndef NDEBUG
+    {
+        std::vector<int32_t> tails_verif;
+        tails_verif.assign(size, -1);
+        for (uint32_t i = 0; i < size; ++i) {
+            auto & cell = cells[i];
+            for (llama_seq_id seq_id : cell.seq_id) {
+                if (tails_verif[seq_id] != -1) {
+                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
+                }
+                tails_verif[seq_id] = i;
+            }
+        }
+        for (uint32_t i = 0; i < size; ++i) {
+            if (tails_verif[i] != cells[i].tail) {
+                LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]);
+            }
+        }
+    }
+#endif
+
+    // find next empty cell
+    uint32_t next_empty_cell = head;
+
+    for (uint32_t i = 0; i < size; ++i) {
+        if (next_empty_cell >= size) { next_empty_cell -= size; }
+        auto & cell = cells[next_empty_cell];
+        if (cell.is_empty()) { break; }
+        next_empty_cell += 1;
+    }
+
+    // find usable cell range
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const llama_seq_id seq_id = ubatch.seq_id[s][0];
+        auto & seq_meta = cells[seq_id];
+        bool has_cell = false;
+        if (seq_meta.tail >= 0) {
+            auto & cell = cells[seq_meta.tail];
+            GGML_ASSERT(cell.has_seq_id(seq_id));
+            // does this seq_id "own" the cell?
+            if (cell.seq_id.size() == 1) { has_cell = true; }
+        }
+        if (!has_cell) {
+            auto & empty_cell = cells[next_empty_cell];
+            GGML_ASSERT(empty_cell.is_empty());
+            // copy old tail into the empty cell
+            if (seq_meta.tail >= 0) {
+                auto & orig_cell = cells[seq_meta.tail];
+                empty_cell.pos = orig_cell.pos;
+                empty_cell.src = orig_cell.src;
+                orig_cell.seq_id.erase(seq_id);
+                empty_cell.seq_id.insert(seq_id); // will be overwritten
+                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
+            }
+            seq_meta.tail = next_empty_cell;
+            // find next empty cell
+            if (s + 1 < n_seqs) {
+                for (uint32_t i = 0; i < size; ++i) {
+                    next_empty_cell += 1;
+                    if (next_empty_cell >= size) { next_empty_cell -= size; }
+                    auto & cell = cells[next_empty_cell];
+                    if (cell.is_empty()) { break; }
+                }
+            }
+        }
+        if (min > seq_meta.tail) { min = seq_meta.tail; }
+        if (max < seq_meta.tail) { max = seq_meta.tail; }
+    }
+
+    // gather and re-order
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        if (dst_id != src_id) {
+            auto & dst_cell = cells[dst_id];
+            auto & src_cell = cells[src_id];
+
+            std::swap(dst_cell.pos, src_cell.pos);
+            std::swap(dst_cell.src, src_cell.src);
+            std::swap(dst_cell.seq_id, src_cell.seq_id);
+
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
+            }
+        }
+    }
+
+    // update the pos of the used seqs
+    for (uint32_t s = 0; s < n_seqs; ++s) {
+        const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
+        const int32_t cell_id = s + min;
+        auto & cell = cells[cell_id];
+
+        if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
+            // What should happen when the pos backtracks or skips a value?
+            // Clearing the state mid-batch would require special-casing which isn't done.
+            LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
+                __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens);
+        }
+        cell.pos = last_pos;
+        cell.seq_id.clear();
+        for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) {
+            const llama_seq_id seq_id = ubatch.seq_id[s][j];
+            cell.seq_id.insert(seq_id);
+            cells[seq_id].tail = cell_id;
+        }
+    }
+
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
+    // allow getting the range of used cells, from head to head + n
+    head = min;
+    n    = max - min + 1;
+    used = std::count_if(cells.begin(), cells.end(),
+        [](const mem_cell & cell){ return !cell.is_empty(); });
+
+    // sanity check
+    return n >= n_seqs;
+}
+
+bool llama_memory_recurrent::get_can_shift() const {
+    // shifting the pos is trivial for recurrent models
+    return true;
+}
+
+size_t llama_memory_recurrent::total_size() const {
+    size_t size = 0;
+    for (const auto & buf : bufs) {
+        size += ggml_backend_buffer_get_size(buf.get());
+    }
+
+    return size;
+}
+
+size_t llama_memory_recurrent::size_r_bytes() const {
+    size_t size_r_bytes = 0;
+
+    for (const auto & r : r_l) {
+        if (r != nullptr) {
+            size_r_bytes += ggml_nbytes(r);
+        }
+    }
+
+    return size_r_bytes;
+}
+
+size_t llama_memory_recurrent::size_s_bytes() const {
+    size_t size_s_bytes = 0;
+
+    for (const auto & s : s_l) {
+        if (s != nullptr) {
+            size_s_bytes += ggml_nbytes(s);
+        }
+    }
+
+    return size_s_bytes;
+}
+
+void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
+    uint32_t cell_count = 0;
+
+    // Count the number of cells with the specified seq_id
+    // Find all the ranges of cells with this seq id (or all, when -1)
+    uint32_t cell_range_begin = size;
+    for (uint32_t i = 0; i < size; ++i) {
+        const auto & cell = cells[i];
+        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+            ++cell_count;
+            if (cell_range_begin == size) {
+                cell_range_begin = i;
+            }
+        } else {
+            if (cell_range_begin != size) {
+                cell_ranges.emplace_back(cell_range_begin, i);
+                cell_range_begin = size;
+            }
+        }
+    }
+    if (cell_range_begin != size) {
+        cell_ranges.emplace_back(cell_range_begin, size);
+    }
+
+    // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+    uint32_t cell_count_check = 0;
+    for (const auto & range : cell_ranges) {
+        cell_count_check += range.second - range.first;
+    }
+    GGML_ASSERT(cell_count == cell_count_check);
+
+    io.write(&cell_count, sizeof(cell_count));
+
+    state_write_meta(io, cell_ranges, seq_id);
+    state_write_data(io, cell_ranges);
+}
+
+void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    uint32_t cell_count;
+    io.read_to(&cell_count, sizeof(cell_count));
+
+    bool res = true;
+
+    res = res && state_read_meta(io, cell_count, seq_id);
+    res = res && state_read_data(io, cell_count);
+
+    if (!res) {
+        if (seq_id == -1) {
+            clear(true);
+        } else {
+            seq_rm(seq_id, -1, -1);
+        }
+        throw std::runtime_error("failed to restore kv cache");
+    }
+}
+
+void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+    for (const auto & range : cell_ranges) {
+        for (uint32_t i = range.first; i < range.second; ++i) {
+            const auto & cell = cells[i];
+            const llama_pos pos      = cell.pos;
+            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+
+            io.write(&pos,      sizeof(pos));
+            io.write(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id) {
+                for (auto seq_id : cell.seq_id) {
+                    io.write(&seq_id, sizeof(seq_id));
+                }
+            }
+        }
+    }
+}
+
+void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+    const uint32_t s_trans = 0;
+    const uint32_t n_layer = hparams.n_layer;
+
+    io.write(&s_trans, sizeof(s_trans));
+    io.write(&n_layer,   sizeof(n_layer));
+
+    std::vector<uint8_t> tmp_buf;
+
+    // Iterate and write all the keys first, each row is a cell
+    // Get whole range at a time
+    for (uint32_t il = 0; il < n_layer; ++il) {
+
+        // Write key type
+        const int32_t r_type_i = (int32_t)r_l[il]->type;
+        io.write(&r_type_i, sizeof(r_type_i));
+
+        // Write row size of key
+        const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        io.write(&r_size_row, sizeof(r_size_row));
+
+        // Read each range of cells of k_size length each into tmp_buf and write out
+        for (const auto & range : cell_ranges) {
+            const size_t range_size = range.second - range.first;
+            const size_t buf_size = range_size * r_size_row;
+            io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
+        }
+    }
+
+    if (!s_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+
+            // Write value type
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            io.write(&s_type_i, sizeof(s_type_i));
+
+            // Write row size of value
+            const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+            io.write(&s_size_row, sizeof(s_size_row));
+
+            // Read each range of cells of s_size length each into tmp_buf and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * s_size_row;
+                io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
+            }
+        }
+    } else {
+        // When v is transposed, we also need the element size and get the element ranges from each row
+        const uint32_t mem_size = size;
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_s = hparams.n_embd_s();
+
+            // Write value type
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            io.write(&s_type_i, sizeof(s_type_i));
+
+            // Write element size
+            const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
+            io.write(&s_size_el, sizeof(s_size_el));
+
+            // Write GQA embedding size
+            io.write(&n_embd_s, sizeof(n_embd_s));
+
+            // For each row, we get the element values of each cell
+            for (uint32_t j = 0; j < n_embd_s; ++j) {
+                // Read each range of cells of v_size_el length each into tmp_buf and write out
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    const size_t src_offset = (range.first + j * mem_size) * s_size_el;
+                    const size_t buf_size = range_size * s_size_el;
+                    io.write_tensor(s_l[il], src_offset, buf_size);
+                }
+            }
+        }
+    }
+}
+
+bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+    if (dest_seq_id != -1) {
+        // single sequence
+
+        seq_rm(dest_seq_id, -1, -1);
+
+        llama_sbatch sbatch;
+        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+
+        batch.n_tokens = cell_count;
+        batch.n_seq_tokens = cell_count;
+        batch.n_seqs = 1;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 0) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            batch.pos[i] = pos;
+        }
+        batch.n_seq_id[0] = 1;
+        batch.seq_id[0] = &dest_seq_id;
+
+        if (!find_slot(batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+
+        // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+        // Assume that this is one contiguous block of cells
+        GGML_ASSERT(head + cell_count <= size);
+        GGML_ASSERT(cells[head].pos == batch.pos[0]);
+        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
+        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+    } else {
+        // whole KV cache restore
+
+        if (cell_count > size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        clear(true);
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            auto & cell = cells[i];
+
+            llama_pos pos;
+            uint32_t  n_seq_id;
+
+            io.read_to(&pos,      sizeof(pos));
+            io.read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cell.pos = pos;
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+
+                // TODO: llama_memory_recurrent should have a notion of max sequences
+                //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                if (seq_id < 0) {
+                    //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+                    return false;
+                }
+
+                cell.seq_id.insert(seq_id);
+
+                int32_t & tail = cells[seq_id].tail;
+                if (tail != -1) {
+                    LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
+                    return false;
+                }
+                tail = i;
+            }
+        }
+
+        head = 0;
+        used = cell_count;
+    }
+
+    for (uint32_t i = 0; i < cell_count; ++i) {
+        uint32_t cell_id = head + i;
+        // make sure the recurrent states will keep their restored state
+        cells[cell_id].src = cell_id;
+    }
+
+    return true;
+}
+
+bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+    uint32_t s_trans;
+    uint32_t n_layer;
+    io.read_to(&s_trans, sizeof(s_trans));
+    io.read_to(&n_layer, sizeof(n_layer));
+
+    if (n_layer != hparams.n_layer) {
+        LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+        return false;
+    }
+    if (cell_count > size) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+        return false;
+    }
+    if (false != (bool) s_trans) {
+        LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
+        return false;
+    }
+
+    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+    for (uint32_t il = 0; il < n_layer; ++il) {
+
+        // Read type of key
+        int32_t r_type_i_ref;
+        io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
+        const int32_t r_type_i = (int32_t) r_l[il]->type;
+        if (r_type_i != r_type_i_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
+            return false;
+        }
+
+        // Read row size of key
+        uint64_t r_size_row_ref;
+        io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
+        const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
+        if (r_size_row != r_size_row_ref) {
+            LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
+            return false;
+        }
+
+        if (cell_count) {
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
+        }
+    }
+
+    if (!s_trans) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+
+            // Read type of value
+            int32_t s_type_i_ref;
+            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            if (s_type_i != s_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of value
+            uint64_t s_size_row_ref;
+            io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
+            const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+            if (s_size_row != s_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
+            }
+        }
+    } else {
+        // For each layer, read the values for each cell (transposed)
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_s = hparams.n_embd_s();
+
+            // Read type of value
+            int32_t s_type_i_ref;
+            io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+            const int32_t s_type_i = (int32_t)s_l[il]->type;
+            if (s_type_i != s_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
+                return false;
+            }
+
+            // Read element size of value
+            uint32_t s_size_el_ref;
+            io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
+            const size_t s_size_el = ggml_type_size(s_l[il]->type);
+            if (s_size_el != s_size_el_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
+                return false;
+            }
+
+            // Read state embedding size
+            uint32_t n_embd_s_ref;
+            io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
+            if (n_embd_s != n_embd_s_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_s; ++j) {
+                    const size_t dst_offset = (head + j * size) * s_size_el;
+                    ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
+                }
+            }
+        }
+    }
+
+    return true;
+}
+
+//
+// llama_memory_recurrent_state
+//
+
+llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}
+
+llama_memory_recurrent_state::llama_memory_recurrent_state(
+        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
+}
+
+llama_memory_recurrent_state::llama_memory_recurrent_state(
+        llama_memory_recurrent * mem,
+        llama_sbatch sbatch,
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+
+llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;
+
+bool llama_memory_recurrent_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_recurrent_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    mem->find_slot(ubatches[i_next]);
+
+    return true;
+}
+
+std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_memory_recurrent_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return ubatches[i_next];
+}
+
+uint32_t llama_memory_recurrent_state::get_n_rs() const {
+    return is_full ? mem->size : mem->n;
+}
+
+uint32_t llama_memory_recurrent_state::get_head() const {
+    return is_full ? 0 : mem->head;
+}
+
+int32_t llama_memory_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : mem->rs_z;
+}
+
+uint32_t llama_memory_recurrent_state::get_size() const {
+    return mem->size;
+}
+
+ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
+    return mem->r_l[il];
+}
+
+ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
+    return mem->s_l[il];
+}
+
+int32_t llama_memory_recurrent_state::s_copy(int i) const {
+    return  mem->cells[i + mem->head].src0;
+}
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
new file mode 100644 (file)
index 0000000..290cc84
--- /dev/null
@@ -0,0 +1,188 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-memory.h"
+
+#include <set>
+#include <vector>
+
+//
+// llama_memory_recurrent
+//
+
+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
+//       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
+class llama_memory_recurrent : public llama_memory_i {
+public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_memory_recurrent(
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_r,
+                    ggml_type    type_s,
+                         bool    offload,
+                     uint32_t    mem_size,
+                     uint32_t    n_seq_max);
+
+    ~llama_memory_recurrent() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    bool prepare(const std::vector<llama_ubatch> & ubatches);
+
+    // find a contiguous slot of memory cells and emplace the ubatch there
+    bool find_slot(const llama_ubatch & ubatch);
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+
+    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+    uint32_t size = 0; // total number of cells, shared across all sequences
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    // first zero-ed state
+    int32_t rs_z = -1;
+
+    // TODO: optimize for recurrent state needs
+    struct mem_cell {
+        llama_pos pos  = -1;
+        int32_t   src  = -1; // used to know where states should be copied from
+        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
+        int32_t   tail = -1;
+
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const mem_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
+    std::vector<mem_cell> cells;
+
+    // per layer
+    std::vector<ggml_tensor *> r_l;
+    std::vector<ggml_tensor *> s_l;
+
+private:
+    //const llama_model & model;
+    const llama_hparams & hparams;
+
+    const uint32_t n_seq_max = 1;
+
+    std::vector<ggml_context_ptr>        ctxs;
+    std::vector<ggml_backend_buffer_ptr> bufs;
+
+    size_t total_size() const;
+
+    size_t size_r_bytes() const;
+    size_t size_s_bytes() const;
+
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
+
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
+};
+
+class llama_memory_recurrent_state : public llama_memory_state_i {
+public:
+    // used for errors
+    llama_memory_recurrent_state(llama_memory_status status);
+
+    // used to create a full-cache state
+    llama_memory_recurrent_state(
+            llama_memory_recurrent * mem);
+
+    // used to create a state from a batch
+    llama_memory_recurrent_state(
+            llama_memory_recurrent * mem,
+            llama_sbatch sbatch,
+            std::vector<llama_ubatch> ubatches);
+
+    virtual ~llama_memory_recurrent_state();
+
+    //
+    // llama_memory_state_i
+    //
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_recurrent_state specific API
+    //
+
+    uint32_t get_n_rs() const;
+    uint32_t get_head() const;
+    int32_t  get_rs_z() const;
+    uint32_t get_size() const;
+
+    ggml_tensor * get_r_l(int32_t il) const;
+    ggml_tensor * get_s_l(int32_t il) const;
+
+    int32_t s_copy(int i) const;
+
+private:
+    const llama_memory_status status;
+
+    llama_memory_recurrent * mem;
+
+    llama_sbatch sbatch;
+
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    //
+    // data needed for building the compute graph for the current ubatch:
+    // TODO: extract all the state like `head` and `n` here
+    //
+
+    const bool is_full = false;
+};
index a5eb122f998d85cbcd309953dd71518271578ceb..a5853f8b12dc0a692ab89dda7e20d064921dd90b 100644 (file)
@@ -8,7 +8,8 @@
 
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -470,6 +471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
     std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
     std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
+    std::fill(
+        hparams.recurrent_layer_arr.begin(),
+        hparams.recurrent_layer_arr.end(),
+        llm_arch_is_recurrent(ml.get_arch()));
 
     std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
 
@@ -9111,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -9120,7 +9125,7 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -9158,12 +9163,12 @@ struct llm_build_mamba : public llm_graph_context {
 
     // TODO: split
     ggml_tensor * build_mamba_layer(
-             ggml_cgraph * gf,
-             ggml_tensor * cur,
-             ggml_tensor * state_copy,
-      const llama_ubatch & ubatch,
-                     int   il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        llm_graph_input_rs * inp,
+               ggml_cgraph * gf,
+               ggml_tensor * cur,
+        const llama_ubatch & ubatch,
+                       int   il) const {
+        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
         const auto kv_head = kv_state->get_head();
 
@@ -9183,17 +9188,17 @@ struct llm_build_mamba : public llm_graph_context {
         GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        ggml_tensor * conv_states_all = kv_state->get_k_l(il);
-        ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
+        ggml_tensor * conv_states_all = kv_state->get_r_l(il);
+        ggml_tensor * ssm_states_all  = kv_state->get_s_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_recurrent_state(
-                gf, conv_states_all, state_copy,
-                hparams.n_embd_k_s(), n_seqs);
+        ggml_tensor * conv = build_rs(
+                inp, gf, conv_states_all,
+                hparams.n_embd_r(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_recurrent_state(
-                gf, ssm_states_all, state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+        ggml_tensor * ssm = build_rs(
+                inp, gf, ssm_states_all,
+                hparams.n_embd_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -11904,13 +11909,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
     }
 
     ggml_tensor * build_rwkv6_time_mix(
+            llm_graph_input_rs * inp,
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
-            ggml_tensor * state_copy,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12031,9 +12036,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
         }
 
-        ggml_tensor * wkv_state = build_recurrent_state(
-                gf, kv_state->get_v_l(il), state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+        ggml_tensor * wkv_state = build_rs(
+                inp, gf, kv_state->get_s_l(il),
+                hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output;
         if (is_qrwkv) {
@@ -12051,9 +12056,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_v_l(il),
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
+                        kv_state->get_s_l(il),
+                        hparams.n_embd_s() * n_seqs,
+                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
                         )
                     )
                 );
@@ -12087,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12097,9 +12102,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12114,7 +12117,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
+            cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12177,14 +12180,14 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
 // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
 struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
     llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+        GGML_ASSERT(n_embd == hparams.n_embd_r());
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12194,9 +12197,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
             cb(att_norm, "attn_norm", il);
@@ -12208,7 +12209,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
+            cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12296,14 +12297,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
     }
 
     ggml_tensor * build_rwkv7_time_mix(
+            llm_graph_input_rs * inp,
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
-            ggml_tensor * state_copy,
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
-        const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
         const auto n_tokens = ubatch.n_tokens;
         const auto n_seqs = ubatch.n_seqs;
@@ -12382,9 +12383,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
-        ggml_tensor * wkv_state = build_recurrent_state(
-                gf, kv_state->get_v_l(il), state_copy,
-                hparams.n_embd_v_s(), n_seqs);
+        ggml_tensor * wkv_state = build_rs(
+                inp, gf, kv_state->get_s_l(il),
+                hparams.n_embd_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
         cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
@@ -12397,9 +12398,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                     wkv_state,
                     ggml_view_1d(
                         ctx0,
-                        kv_state->get_v_l(il),
-                        hparams.n_embd_v_s() * n_seqs,
-                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
+                        kv_state->get_s_l(il),
+                        hparams.n_embd_s() * n_seqs,
+                        hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
                         )
                     )
                 );
@@ -12440,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12450,9 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@@ -12467,7 +12466,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12525,7 +12524,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
 
 struct llm_build_arwkv7 : public llm_build_rwkv7_base {
     llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
-        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+        GGML_ASSERT(n_embd == hparams.n_embd_r());
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -12533,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        ggml_tensor * state_copy = build_inp_s_copy();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12543,9 +12542,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
             const llama_layer * layer = &model.layers[il];
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
-            ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, ubatch, il
-                    );
+            ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
             cb(att_norm, "attn_norm", il);
@@ -12557,7 +12554,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -13738,6 +13735,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13747,57 +13746,75 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
-
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
-
-                    res = new llama_kv_cache_unified(
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_memory_recurrent(
                             *this,
                             nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
                             cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid(arch)) {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    res = new llama_memory_hybrid(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ padding,
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
+
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
@@ -14377,14 +14394,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
 }
 
 bool llama_model_is_recurrent(const llama_model * model) {
-    switch (model->arch) {
-        case     LLM_ARCH_MAMBA:      return true;
-        case     LLM_ARCH_RWKV6:      return true;
-        case     LLM_ARCH_RWKV6QWEN2: return true;
-        case     LLM_ARCH_RWKV7:      return true;
-        case     LLM_ARCH_ARWKV7:     return true;
-        default: return false;
-    }
+    return llm_arch_is_recurrent(model->arch);
 }
 
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {