]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kv-cache : avoid modifying recurrent cells when setting inputs (#13834)
authorcompilade <redacted>
Tue, 10 Jun 2025 22:20:14 +0000 (18:20 -0400)
committerGitHub <redacted>
Tue, 10 Jun 2025 22:20:14 +0000 (18:20 -0400)
* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

It was replaced with equivalent and simpler functionality
with rs_z (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

The `state_copy` shuffle assumes everything is moved at once,
which is not true when `states_extra` is copied back to the cache
before copying the range of states between `head` and `head + n_seqs`.
This is only a problem if any of the cells in [`head`, `head + n_seqs`)
have an `src` in [`head + n_seqs`, `head + n_kv`),
which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2),
this will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size
and the number of states.

src/llama-graph.cpp
src/llama-graph.h
src/llama-kv-cache-recurrent.cpp
src/llama-kv-cache-recurrent.h
src/llama-kv-cache-unified.cpp
src/llama-model.cpp

index 56082279119d89e2f97ee639c66d8ad89afbc4cf..e74c9ff53b05a47f12212ecbbcd3d618c09ed78a 100644 (file)
@@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_state->get_n_kv();
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_mask(i);
-        }
-    }
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -987,23 +971,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1456,43 +1423,53 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
-             int32_t   n_state,
-             int32_t   n_seqs) const {
+             int32_t   state_size,
+             int32_t   n_seqs,
+                bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_cgraph * gf,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1503,8 +1480,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
index 28da6a5228bdcbb928d2d16c8da1e2a6bce7b027..88fb77f1ddc9a330a8353ec19b161574d455e909 100644 (file)
@@ -200,18 +200,6 @@ public:
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_s_mask : public llm_graph_input_i {
-public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_mask() = default;
-
-    void set_input(const llama_ubatch * ubatch) override;
-
-    ggml_tensor * s_mask; // F32 [1, n_kv]
-
-    const llama_kv_cache_recurrent_state * kv_state;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -521,7 +509,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
     ggml_tensor * build_inp_s_copy() const;
-    ggml_tensor * build_inp_s_mask() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -606,18 +593,17 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_copy_mask_state(
+    ggml_tensor * build_recurrent_state(
              ggml_cgraph * gf,
              ggml_tensor * s,
              ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
-                 int32_t   n_state,
-                 int32_t   n_seqs) const;
+                 int32_t   state_size,
+                 int32_t   n_seqs,
+                    bool   avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
              ggml_cgraph * gf,
              ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const;
 
index f5c6dcd66ce9e8a519d7158a28b9cd5fcd29568b..f8cdd52808d7be521c71e9bbbc8f41fdf7128db6 100644 (file)
@@ -406,21 +406,12 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 
     bool success = true;
 
-    // TODO: here we have to verify that all ubatches can fit in the cells
-    //       however, the current implementation is broken because it relies on s_copy() and s_mask() to update the cells
-    //         during the compute of each ubatch. to reproduce, uncomment the following loop and run:
-    //
-    //           $ llama-parallel -m ./mamba-130m/ggml-model-f16.gguf -np 5 -ns 8
-    //
-    //       recovery from failures when the batch does not fit in the KV cache will not work correctly until this is fixed
-    //
-    GGML_UNUSED(ubatches);
-    //for (const auto & ubatch : ubatches) {
-    //    if (!find_slot(ubatch)) {
-    //        success = false;
-    //        break;
-    //    }
-    //}
+    for (const auto & ubatch : ubatches) {
+        if (!find_slot(ubatch)) {
+            success = false;
+            break;
+        }
+    }
 
     // restore the original state
     cells = std::move(org_cells);
@@ -431,14 +422,13 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
 }
 
 bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
-    const uint32_t n_tokens = ubatch.n_tokens;
-    const uint32_t n_seqs   = ubatch.n_seqs;
+    const uint32_t n_seqs = ubatch.n_seqs;
 
     const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*n_tokens) {
+    if (head > used + 2*n_seqs) {
         head = 0;
     }
 
@@ -534,16 +524,16 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
                 empty_cell.src = orig_cell.src;
                 orig_cell.seq_id.erase(seq_id);
                 empty_cell.seq_id.insert(seq_id); // will be overwritten
+                GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
             }
             seq_meta.tail = next_empty_cell;
             // find next empty cell
             if (s + 1 < n_seqs) {
-                next_empty_cell += 1;
                 for (uint32_t i = 0; i < size; ++i) {
+                    next_empty_cell += 1;
                     if (next_empty_cell >= size) { next_empty_cell -= size; }
                     kv_cell & cell = cells[next_empty_cell];
                     if (cell.is_empty()) { break; }
-                    next_empty_cell += 1;
                 }
             }
         }
@@ -553,8 +543,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 
     // gather and re-order
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        int32_t dst_id = s + min;
-        int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
+        const int32_t dst_id = s + min;
+        const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
         if (dst_id != src_id) {
             kv_cell & dst_cell = cells[dst_id];
             kv_cell & src_cell = cells[src_id];
@@ -563,12 +553,14 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
             std::swap(dst_cell.src, src_cell.src);
             std::swap(dst_cell.seq_id, src_cell.seq_id);
 
-            // swap tails (assuming they NEVER overlap)
-            for (const llama_seq_id seq_id : src_cell.seq_id) {
-                cells[seq_id].tail = src_id;
-            }
-            for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                cells[seq_id].tail = dst_id;
+            // swap tails
+            for (uint32_t i = 0; i < size; ++i) {
+                int32_t & tail = cells[i].tail;
+                if (tail == src_id) {
+                    tail = dst_id;
+                } else if (tail == dst_id) {
+                    tail = src_id;
+                }
             }
         }
     }
@@ -576,7 +568,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
     // update the pos of the used seqs
     for (uint32_t s = 0; s < n_seqs; ++s) {
         const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-        int32_t cell_id = s + min;
+        const int32_t cell_id = s + min;
         kv_cell & cell = cells[cell_id];
 
         if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
@@ -594,6 +586,38 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
         }
     }
 
+    // Find first cell without src refs, to use as the zero-ed state
+    {
+        // TODO: bake-in src refcounts in the cell metadata
+        std::vector<int32_t> refcounts(size, 0);
+        for (size_t i = 0; i < size; ++i) {
+            const int32_t src = cells[i].src;
+            if (src >= 0) {
+                refcounts[src] += 1;
+            }
+        }
+
+        rs_z = -1;
+        for (int i = min; i <= max; ++i) {
+            if (refcounts[i] == 0) {
+                rs_z = i;
+                break;
+            }
+        }
+
+        for (int i = min; i <= max; ++i) {
+            if (cells[i].src < 0) {
+                GGML_ASSERT(rs_z >= 0);
+                cells[i].src0 = rs_z;
+            } else {
+                // Stage the source ids for all used cells to allow correct seq_* behavior
+                // and still make these values available when setting the inputs
+                cells[i].src0 = cells[i].src;
+            }
+            cells[i].src = i; // avoid moving or clearing twice
+        }
+    }
+
     // allow getting the range of used cells, from head to head + n
     head = min;
     n    = max - min + 1;
@@ -605,47 +629,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
 }
 
 bool llama_kv_cache_recurrent::get_can_shift() const {
-    return false;
-}
-
-int32_t llama_kv_cache_recurrent::s_copy(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    // prevent out-of-bound sources
-    if (cell.src < 0 || (uint32_t) cell.src >= size) {
-        cell.src = cell_id;
-    }
-
-    int32_t res = cell.src;
-
-    // TODO: do not mutate the KV cache
-    // ensure copy only happens once
-    if (cell.src != (int32_t) cell_id) {
-        cell.src = cell_id;
-    }
-
-    return res;
-}
-
-float llama_kv_cache_recurrent::s_mask(int i) const {
-    const uint32_t cell_id = i + head;
-
-    //////////////////////////////////////////////
-    // TODO: this should not mutate the KV cache !
-    kv_cell & cell = const_cast<kv_cell &>(cells[cell_id]);
-
-    float res = (float) (cell.src >= 0);
-
-    // only clear once
-    if (cell.src < 0) {
-        cell.src = cell_id;
-    }
-
-    return res;
+    // shifting the pos is trivial for recurrent models
+    return true;
 }
 
 size_t llama_kv_cache_recurrent::total_size() const {
@@ -1111,6 +1096,10 @@ uint32_t llama_kv_cache_recurrent_state::get_head() const {
     return is_full ? 0 : kv->head;
 }
 
+int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
+    return is_full ? 0 : kv->rs_z;
+}
+
 uint32_t llama_kv_cache_recurrent_state::get_size() const {
     return kv->size;
 }
@@ -1124,9 +1113,5 @@ ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
 }
 
 int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
-    return kv->s_copy(i);
-}
-
-float llama_kv_cache_recurrent_state::s_mask(int i) const {
-    return kv->s_mask(i);
+    return  kv->cells[i + kv->head].src0;
 }
index d1da1225655fa4b18d0ae055df131ebff2c9d4cc..4b33bafd71cca510374ae55f3993100b61fe21c4 100644 (file)
@@ -57,10 +57,6 @@ public:
 
     bool get_can_shift() const override;
 
-    // TODO: temporary methods - they are not really const as they do const_cast<>, fix this
-    int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
-
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -73,10 +69,14 @@ public:
     // computed before each graph build
     uint32_t n = 0;
 
+    // first zero-ed state
+    int32_t rs_z = -1;
+
     // TODO: optimize for recurrent state needs
     struct kv_cell {
         llama_pos pos  = -1;
-        int32_t   src  = -1; // used to copy states
+        int32_t   src  = -1; // used to know where states should be copied from
+        int32_t   src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
         int32_t   tail = -1;
 
         std::set<llama_seq_id> seq_id;
@@ -157,13 +157,13 @@ public:
 
     uint32_t get_n_kv() const;
     uint32_t get_head() const;
+    int32_t  get_rs_z() const;
     uint32_t get_size() const;
 
     ggml_tensor * get_k_l(int32_t il) const;
     ggml_tensor * get_v_l(int32_t il) const;
 
     int32_t s_copy(int i) const;
-    float   s_mask(int i) const;
 
 private:
     const llama_memory_status status;
index 3566d5fd4d72bdf5a7e8a968f18197e7743815ae..fe41b948043104b9bd8c946b9c94bfb7d283a145 100644 (file)
@@ -512,8 +512,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
         head_cur = 0;
     }
 
-    // otherwise, one cell per token.
-
     if (n_tokens > cells.size()) {
         LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return -1;
index f4a66390c79812a06a7999e4372333b494f0b673..c64bf9de939f4395335e01f9fd0d1a766ab3993c 100644 (file)
@@ -8857,7 +8857,6 @@ struct llm_build_mamba : public llm_graph_context {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -8866,8 +8865,7 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il);
-            cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il);
+            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8908,7 +8906,6 @@ struct llm_build_mamba : public llm_graph_context {
              ggml_cgraph * gf,
              ggml_tensor * cur,
              ggml_tensor * state_copy,
-             ggml_tensor * state_mask,
       const llama_ubatch & ubatch,
                      int   il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -8935,12 +8932,12 @@ struct llm_build_mamba : public llm_graph_context {
         ggml_tensor * ssm_states_all  = kv_state->get_v_l(il);
 
         // (ab)using the KV cache to store the states
-        ggml_tensor * conv = build_copy_mask_state(
-                gf, conv_states_all, state_copy, state_mask,
+        ggml_tensor * conv = build_recurrent_state(
+                gf, conv_states_all, state_copy,
                 hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * ssm = build_copy_mask_state(
-                gf, ssm_states_all, state_copy, state_mask,
+        ggml_tensor * ssm = build_recurrent_state(
+                gf, ssm_states_all, state_copy,
                 hparams.n_embd_v_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
@@ -11656,7 +11653,6 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             const llama_ubatch & ubatch,
             int   il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -11780,8 +11776,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
             k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
         }
 
-        ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_state->get_v_l(il), state_copy, state_mask,
+        ggml_tensor * wkv_state = build_recurrent_state(
+                gf, kv_state->get_v_l(il), state_copy,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output;
@@ -11837,7 +11833,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11848,7 +11843,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -11864,7 +11859,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -11935,7 +11930,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -11946,7 +11940,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -11959,7 +11953,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                     1
                     );
 
-            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il);
+            cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@@ -12051,7 +12045,6 @@ struct llm_build_rwkv7_base : public llm_graph_context {
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
             ggml_tensor *& first_layer_value,
             const llama_ubatch & ubatch,
             int   il) const {
@@ -12134,8 +12127,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
         v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
 
-        ggml_tensor * wkv_state = build_copy_mask_state(
-                gf, kv_state->get_v_l(il), state_copy, state_mask,
+        ggml_tensor * wkv_state = build_recurrent_state(
+                gf, kv_state->get_v_l(il), state_copy,
                 hparams.n_embd_v_s(), n_seqs);
 
         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@@ -12193,7 +12186,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12204,7 +12196,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
@@ -12220,7 +12212,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
             cb(ffn_inp, "ffn_inp", il);
@@ -12287,7 +12279,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
 
         ggml_tensor * state_copy = build_inp_s_copy();
-        ggml_tensor * state_mask = build_inp_s_mask();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12298,7 +12289,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
 
             ggml_tensor * token_shift = build_rwkv_token_shift_load(
-                    gf, state_copy, state_mask, ubatch, il
+                    gf, state_copy, ubatch, il
                     );
 
             ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
@@ -12311,7 +12302,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                     1
                     );
 
-            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il);
+            cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
 
             token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));