]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : reduce splits for recurrent and hybrid models (#14825)
authorcompilade <redacted>
Thu, 31 Jul 2025 05:02:46 +0000 (01:02 -0400)
committerGitHub <redacted>
Thu, 31 Jul 2025 05:02:46 +0000 (08:02 +0300)
* graph : avoid creating redundant s_copy views

* graph : comment the s_copy views

src/llama-graph.cpp
src/llama-graph.h

index 702192b79df6ece99d152eaac00fcb0a8f39f009..ee861bd7ec18cea63a0f5c1902b08a57326bd342 100644 (file)
@@ -1644,16 +1644,17 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
 ggml_tensor * llm_graph_context::build_rs(
         ggml_tensor * s,
-        ggml_tensor * state_copy,
+        ggml_tensor * state_copy_main,
+        ggml_tensor * state_copy_extra,
             int32_t   state_size,
             int32_t   n_seqs,
-           uint32_t   n_kv,
-           uint32_t   kv_head,
-           uint32_t   kv_size,
+           uint32_t   n_rs,
+           uint32_t   rs_head,
+           uint32_t   rs_size,
             int32_t   rs_zero,
         const llm_graph_get_rows_fn & get_state_rows) const {
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1661,39 +1662,44 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
 
     return output_states;
 }
 
 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
            ggml_context * ctx0,
+     const llama_ubatch & ubatch,
     const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
     return inp;
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1706,7 +1712,9 @@ ggml_tensor * llm_graph_context::build_rs(
         const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+                    kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+                    get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1753,7 +1761,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
index 8eae4f5515e3ca43a33908abcd5826eae9d8a768..55a6b6f3e05b3decf44e7914b8dbf0ab39f06f0e 100644 (file)
@@ -214,7 +214,12 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy; // I32 [kv_size]
+    ggml_tensor * s_copy;  // I32 [n_rs]
+
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
+    ggml_tensor * s_copy_main;   // I32 [n_seqs]
+    ggml_tensor * s_copy_extra;  // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
 };
@@ -730,7 +735,6 @@ struct llm_graph_context {
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
     // TODO: move this implementation to llama_memory_recurrent.
     //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
@@ -738,12 +742,13 @@ struct llm_graph_context {
     //         `llama_memory_recurrent`
     ggml_tensor * build_rs(
             ggml_tensor * s,
-            ggml_tensor * state_copy,
+            ggml_tensor * state_copy_main,
+            ggml_tensor * state_copy_extra,
                 int32_t   state_size,
                 int32_t   n_seqs,
-               uint32_t   n_kv,
-               uint32_t   kv_head,
-               uint32_t   kv_size,
+               uint32_t   n_rs,
+               uint32_t   rs_head,
+               uint32_t   rs_size,
                 int32_t   rs_zero,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;