]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
models : dedup Kimi Linear delta net implementation (#19668)
authorymcki <redacted>
Thu, 19 Feb 2026 06:15:17 +0000 (14:15 +0800)
committerGitHub <redacted>
Thu, 19 Feb 2026 06:15:17 +0000 (08:15 +0200)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments [no ci]

* add kimi linear to delta-net-base

* removed unnecessary ggml_cont from g_exp_t

* removed ggml_cont from g_diff_exp_t. moved ggml_cont for o to kimi-linear.cpp

* removed unnecessary diag mask

* cont : simplify

* cont : avoid graph splits

* scale q after mul instead of beginning

* scale q after mul instead of beginning

* identical ppl

* cont : fix scale and decay mask

* minor : remove TODO

---------

Co-authored-by: Georgi Gerganov <redacted>
src/models/delta-net-base.cpp
src/models/kimi-linear.cpp
src/models/models.h
src/models/qwen3next.cpp

index 0cdf9c324bacd85a3c1017f6c62e2194d6ec4693..99f1fdd9538cad6d318aeef7094e169ad21c8a27 100644 (file)
@@ -27,6 +27,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
 
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];
+    const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);
 
     GGML_ASSERT(S_k == S_v);
     GGML_ASSERT(H_v % H_k == 0);
@@ -35,9 +36,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
     GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
 
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
-    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
 
     const float scale = 1.0f / sqrtf(S_k);
 
@@ -52,8 +54,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
     k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
     v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
-    g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [  1, n_tokens, H_v, n_seqs]
-    b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [  1, n_tokens, H_v, n_seqs]
+    g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
+    b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [  1, n_tokens, H_v, n_seqs]
 
     const int CS = CHUNK_SIZE;
 
@@ -78,33 +80,76 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
     v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
 
-    g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
-    b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
+    g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1,        CS, n_chunks, H_v * n_seqs);
 
-    // [CS, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
+    // [CS, g_0, n_chunks, H_v * n_seqs]
+    // TODO: extend ggml_cumsum with axis parameter to avoid transpose
+    ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
     cb(g_cs, "g_cs", il);
 
-    ggml_tensor * g_cs_i = g_cs;
-    ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
+    ggml_tensor * kb = nullptr;
+    ggml_tensor * kq = nullptr;
+    if (kda) {
+        const int64_t CHB = n_chunks * H_k * n_seqs;
 
-    g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
+        ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
+        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB]
 
-    // [CS, CS, n_chunks, H_v * n_seqs]
-    ggml_tensor * decay_mask;
-    decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
-    decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    cb(decay_mask, "decay_mask", il);
+        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
 
-    // [CS, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * kb;
-    kb = ggml_mul_mat(ctx0, k,  k_b);
-    kb = ggml_mul    (ctx0, kb, decay_mask);
+        // decay_mask [chunk_size,chunk_size,S_k,CHB]
+        ggml_tensor * decay_mask;
+        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        cb(decay_mask, "decay_mask", il);
+
+        // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+        decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);
+
+        ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS,  1, CHB);
+        ggml_tensor * k_j   = ggml_reshape_4d(ctx0, k,   S_k,  1, CS, CHB);
+        ggml_tensor * q_i   = ggml_reshape_4d(ctx0, q,   S_k, CS,  1, CHB);
+
+        ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
+        ggml_tensor * decay_q_i   = ggml_mul(ctx0, decay_mask, q_i);
+
+        // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+        kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
+        kq = ggml_mul_mat(ctx0, decay_q_i,   k_j);
+
+        kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
+        kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
+    } else {
+        ggml_tensor * g_cs_i = g_cs;
+        ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
+
+        g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
+
+        // [CS, CS, n_chunks, H_v * n_seqs]
+        ggml_tensor * decay_mask;
+        decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+        decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+        decay_mask = ggml_exp(ctx0, decay_mask);
+        cb(decay_mask, "decay_mask", il);
+
+        // [CS, CS, n_chunks, H_k * n_seqs]
+        kb = ggml_mul_mat(ctx0, k,  k_b);
+        kb = ggml_mul    (ctx0, kb, decay_mask);
+
+        // [CS, CS, n_chunks, H_k * n_seqs]
+        kq = ggml_mul_mat(ctx0, k, q);
+        kq = ggml_mul(ctx0, kq, decay_mask);
+    }
+
+    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
+    cb(kq, "kq", il);
 
     // [CS, CS, n_chunks, H_k * n_seqs]
     ggml_tensor * attn;
     attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
+    cb(attn, "attn", il);
 
     ggml_tensor * identity;
     identity = ggml_view_1d(ctx0, attn, CS, 0);
@@ -115,6 +160,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     cb(lhs, "dnet_add_ch_lhs", il);
 
     attn = ggml_neg(ctx0, attn);
+    cb(attn, "attn_pre_solve", il);
 
     ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
     attn = ggml_add(ctx0, lin_solve, identity);
@@ -123,7 +169,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     // [S_v, CS, n_chunks, H_v * n_seqs]
     v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
 
-    // [CS, 1, n_chunks, H_v * n_seqs]
+    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
     ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
 
     k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
@@ -136,16 +182,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
     cb(k_cd, "k_cumdecay", il);
 
-    // [S_k, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
+    // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
     ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
 
-    // [CS, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-    kq = ggml_mul(ctx0, kq, decay_mask);
-    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
-    cb(kq, "kq", il);
-
     // vectorized calculation of key_gdiff
     // improved from the chunked version:
     //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
@@ -156,8 +196,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
 
     // get last element in g_cumsum along CS dimension (ne0)
     // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    // [1, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
+    // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
             g_cs->nb[1],
             g_cs->nb[2],
             g_cs->nb[3],
@@ -167,16 +207,15 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     // TODO: remove this cont when CUDA supports non-cont unary ops
     g_last = ggml_cont(ctx0, g_last);
 
-    // [1, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il);
+    // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
+    cb(g_last_exp_t, "g_last_exp_t", il);
 
-    // [CS, 1, n_chunks, H_v * n_seqs]
+    // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
     ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
     cb(g_diff, "g_diff", il);
 
-    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
-    ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
+    ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));
 
     // [S_k, CS, n_chunks, H_v * n_seqs]
     ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
@@ -227,8 +266,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
         ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
 
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
-        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
+        ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
+
+        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
         s_t = ggml_add(ctx0, s_t, kgv);
         cb(s_t, "dnet_add_ch_state", il);
     }
@@ -241,9 +281,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
             ggml_row_size(v->type, S_v),
             ggml_row_size(v->type, S_v * CS * n_chunks),
             ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
-
     o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
-    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
+    s = ggml_transpose(ctx0, s_t);
+    cb(s, "output_state", il);
 
     return {o, s};
 }
@@ -273,9 +313,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
     GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
 
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
-    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
 
     const float scale = 1.0f / sqrtf(S_k);
 
@@ -291,8 +332,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     cb(b, "b_in", il);
     cb(g, "g_in", il);
 
-    g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
-    b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
+    // GDA: [1,  1,  H_v, n_seqs]
+    // KDA: [1, S_k, H_v, n_seqs]
+    g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1,        1, H_v, n_seqs);
 
     // [S_v, S_v, H_v, n_seqs]
     g = ggml_exp(ctx0, g);
index 133834021d09b115e6c57dee84befb83c3a89846..8173d894ef2ef2dfda2aa2667b8ca38ac0522923 100644 (file)
@@ -3,8 +3,6 @@
 
 #include "llama-memory-recurrent.h"
 
-#define CHUNK_SIZE 64
-
 // Causal Conv1d function for Q,K,V
 // When qkv is 0, it is Q, 1 is K, 2 is V
 static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
@@ -67,7 +65,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t
 }
 
 llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
-    llm_build_mamba_base(params), model(model) {
+    llm_build_delta_net_base(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -86,17 +84,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
     // Output ids for selecting which tokens to output
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    ggml_tensor * chunked_causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
-                    GGML_TRI_TYPE_LOWER);
-
-    ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
-
-    ggml_build_forward_expand(gf, chunked_causal_mask);
-    ggml_build_forward_expand(gf, chunked_identity);
-    ggml_build_forward_expand(gf, chunked_diag_mask);
-
     // Kimi dimension constants
     const int64_t n_head = hparams.n_head();
     const int64_t head_dim = hparams.n_embd_head_kda;
@@ -177,12 +164,22 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
             ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
             state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
-            // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
+
+            const float eps_norm = hparams.f_norm_rms_eps;
+
+            Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm);
+            Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
+            beta = ggml_sigmoid(ctx0, beta);
+
+            beta = ggml_reshape_4d(ctx0, beta,        1, n_head, n_seq_tokens, n_seqs);
+            g1   = ggml_reshape_4d(ctx0, g1,   head_dim, n_head, n_seq_tokens, n_seqs);
+
+            // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
             std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
-                build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
-                build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
+                build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
+                build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);
 
-            ggml_tensor * output = attn_out.first;
+            ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
             ggml_tensor * new_state = attn_out.second;
             cb(output, "attn_output", il);
             cb(new_state, "new_state", il);
@@ -393,385 +390,3 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
 
     ggml_build_forward_expand(gf, cur);
 }
-
-/*
-    This is a ggml implementation of the naive_chunk_kda function of
-    https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
-*/
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * gk,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
-        int           il) {
-    GGML_ASSERT(ggml_is_contiguous(state));
-
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    // TODO: can this ever be false?
-    const bool use_qk_l2norm = true;
-
-    if (use_qk_l2norm) {
-        const float eps_norm = hparams.f_norm_rms_eps;
-
-        q = ggml_l2_norm(ctx0, q, eps_norm);
-        k = ggml_l2_norm(ctx0, k, eps_norm);
-    }
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(gk, "gk_in", il);
-
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(gk, "gk_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
-
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
-
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    gk = ggml_pad(ctx0, gk, 0, pad, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
-
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(gk, "gk_pad", il);
-
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
-
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
-
-    const int64_t HB = H_k * n_seqs;
-
-    q      = ggml_cont_4d(ctx0, q,      S_k, chunk_size, n_chunks, HB);
-    k      = ggml_cont_4d(ctx0, k,      S_k, chunk_size, n_chunks, HB);
-    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
-    v      = ggml_cont_4d(ctx0, v,      S_v, chunk_size, n_chunks, HB);
-    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);
-
-    gk    = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
-    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);
-
-    // switch for cumsum
-    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
-    cb(gk, "gk", il);
-    ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
-    cb(gk_cumsum, "gk_cumsum", il);
-
-/*
-    Compute Akk and Aqk loop together
-    Akk loop:
-    for i in range(BT):
-        k_i = k[..., i, :] # k_i [B,H,NT,S]
-        g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S]
-        A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
-    Aqk loop:
-    for j in range(BT):
-        k_j = k[:, :, i, j]
-        g_j = g[:, :, i, j:j+1, :]
-        A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
-*/
-    const int64_t CHB = n_chunks * H_k * n_seqs;
-    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
-    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB]
-
-    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
-    // decay_mask [chunk_size,chunk_size,S_k,CHB]
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
-    cb(decay_mask, "decay_mask", il);
-
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    cb(decay_mask, "decay_masked", il);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-
-    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
-    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
-
-    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
-    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
-    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
-
-    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
-    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
-
-    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
-    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
-    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
-    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
-    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
-    cb(Akk, "Akk", il);
-    cb(Aqk, "Aqk", il);
-
-    Akk = ggml_mul(ctx0, Akk, beta);
-    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
-    cb(Akk, "attn_pre_solve", il);
-
-    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
-    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
-    cb(Aqk, "Aqk_masked", il);
-
-    // for i in range(1, chunk_size):
-    //          row = attn[..., i, :i].clone()
-    //          sub = attn[..., :i, :i].clone()
-    //          attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
-    //
-    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
-    ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
-
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
-    Akk                      = ggml_mul(ctx0, lin_solve, causal_mask);
-    Akk                      = ggml_add(ctx0, Akk, identity);
-
-    cb(Akk, "attn_solved", il);
-
-    // switch back for downstream
-    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
-    ggml_tensor * gkexp      = ggml_exp(ctx0, gk_cumsum);
-    cb(gk_cumsum, "gk_cumsum", il);
-
-    // u = (A*beta[..., None, :]) @ v  aka U_[t]
-    ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
-
-    ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
-    cb(kbeta_gkexp, "kbeta_gkexp", il);
-
-    ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
-    cb(k_cumdecay, "k_cumdecay", il);
-
-    ggml_tensor * core_attn_out = nullptr;
-    ggml_tensor * new_state = ggml_dup(ctx0, state);
-
-    cb(new_state, "new_state", il);
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-// extract one chunk worth of data
-        auto chunkify = [=](ggml_tensor * t) {
-                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
-        auto chunkify_A = [=](ggml_tensor * t) {
-                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
-
-
-// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
-        ggml_tensor * k_chunk = chunkify(k);
-        ggml_tensor * q_chunk = chunkify(q);
-        ggml_tensor * vb_chunk = chunkify(vb);
-
-// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
-        ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
-        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
-        ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
-        ggml_tensor * Aqk_chunk = chunkify_A(Aqk);
-
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
-
-        // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-
-        // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-
-        // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        // or Gamma_[t]*Q_]t] @ S
-        ggml_tensor * q_gk_exp   = ggml_mul(ctx0, q_chunk, gkexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
-        attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
-
-        // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk);
-
-        // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-
-        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
-
-        ggml_tensor * gk_cum_last =
-            ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
-                                        gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
-                                        gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));
-
-        ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
-
-        ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));
-
-        ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
-
-        ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
-
-        // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
-
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
-    }
-
-    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
-
-    // truncate padded tokens
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
-            S_v, n_tokens, H_v, n_seqs,
-            ggml_row_size(core_attn_out->type, S_v),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-    // permute back to (S_v, H_v, n_tokens, n_seqs)
-    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-
-    cb(new_state, "output_state", il);
-
-    return {output_tokens, new_state};
-}
-
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoregressive(
-    ggml_tensor * q,
-    ggml_tensor * k,
-    ggml_tensor * v,
-    ggml_tensor * gk,
-    ggml_tensor * beta,
-    ggml_tensor * state,
-    int il) {
-    GGML_ASSERT(ggml_is_contiguous(v));
-    GGML_ASSERT(ggml_is_contiguous(gk));
-
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
-
-    const float scale = 1.0f / sqrtf(S_v);
-
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
-    cb(gk, "gk_in", il);
-
-// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
-// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
-// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
-    gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
-    ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
-
-    // Apply exponential to gk_t
-    gk_t = ggml_exp(ctx0, gk_t);
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * gk_t
-    // S = S * g_i[..., None].exp()
-    state = ggml_mul(ctx0, state, gk_t);
-
-    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
-
-// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
-    k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
-    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);
-
-    // v_i - (k_i[..., None] * S).sum(-2)
-    v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);
-
-    // b_i[..., None] * k_i
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);
-
-    // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
-    // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
-    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
-
-    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
-    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
-    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
-
-    return {core_attn_out, state};
-}
-
index 920a8e5798f395ee9ceb615053a285c704ce88b5..6d11a6ceef8e96d0576eaca8e1ea52dc40d82cb5 100644 (file)
@@ -320,8 +320,7 @@ struct llm_build_jamba : public llm_build_mamba_base {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
-// TODO: derive llm_build_delta_net_base instead
-struct llm_build_kimi_linear : public llm_build_mamba_base {
+struct llm_build_kimi_linear : public llm_build_delta_net_base {
     llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
 
     std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
index 0fdf2d42c252f8cda4e31f50dfa85fd230aedb49..974120ea6f2b99b89f17cacd01398c2a97cc4a16 100644 (file)
@@ -306,8 +306,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 
     ggml_tensor * beta = ggml_sigmoid(ctx0, b);
 
-    beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
-
     // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
@@ -318,6 +316,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
+    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
+    gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);
+
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);