]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
models : deduplicate delta-net graphs for Qwen family (#19597)
authorGeorgi Gerganov <redacted>
Mon, 16 Feb 2026 12:35:04 +0000 (14:35 +0200)
committerGitHub <redacted>
Mon, 16 Feb 2026 12:35:04 +0000 (14:35 +0200)
* models : add llm_build_delta_net_base

* cont : keep qwen35 and qwen35moe graphs intact

* cont : add comments

17 files changed:
src/CMakeLists.txt
src/models/delta-net-base.cpp [new file with mode: 0644]
src/models/falcon-h1.cpp
src/models/granite-hybrid.cpp
src/models/graph-context-mamba.cpp [deleted file]
src/models/jamba.cpp
src/models/kimi-linear.cpp
src/models/mamba-base.cpp [new file with mode: 0644]
src/models/mamba.cpp
src/models/models.h
src/models/nemotron-h.cpp
src/models/plamo2.cpp
src/models/qwen35.cpp
src/models/qwen35moe.cpp
src/models/qwen3next.cpp
src/models/rwkv6-base.cpp
src/models/rwkv7-base.cpp

index fdda05d3ea4a8ee14212528e72498b50dad7f05b..daf249422a113aebd30a2596c3c926671613263b 100644 (file)
@@ -57,13 +57,14 @@ add_library(llama
             models/deci.cpp
             models/deepseek.cpp
             models/deepseek2.cpp
+            models/delta-net-base.cpp
             models/dots1.cpp
             models/dream.cpp
             models/ernie4-5-moe.cpp
             models/ernie4-5.cpp
+            models/exaone-moe.cpp
             models/exaone.cpp
             models/exaone4.cpp
-            models/exaone-moe.cpp
             models/falcon-h1.cpp
             models/falcon.cpp
             models/gemma-embedding.cpp
@@ -91,10 +92,12 @@ add_library(llama
             models/llama-iswa.cpp
             models/llama.cpp
             models/maincoder.cpp
+            models/mamba-base.cpp
             models/mamba.cpp
             models/mimo2-iswa.cpp
             models/minicpm3.cpp
             models/minimax-m2.cpp
+            models/mistral3.cpp
             models/modern-bert.cpp
             models/mpt.cpp
             models/nemotron-h.cpp
@@ -118,12 +121,12 @@ add_library(llama
             models/qwen2moe.cpp
             models/qwen2vl.cpp
             models/qwen3.cpp
-            models/qwen3vl.cpp
-            models/qwen3vl-moe.cpp
-            models/qwen3moe.cpp
-            models/qwen3next.cpp
             models/qwen35.cpp
             models/qwen35moe.cpp
+            models/qwen3moe.cpp
+            models/qwen3next.cpp
+            models/qwen3vl-moe.cpp
+            models/qwen3vl.cpp
             models/refact.cpp
             models/rnd1.cpp
             models/rwkv6-base.cpp
@@ -142,8 +145,6 @@ add_library(llama
             models/t5-enc.cpp
             models/wavtokenizer-dec.cpp
             models/xverse.cpp
-            models/mistral3.cpp
-            models/graph-context-mamba.cpp
             )
 
 set_target_properties(llama PROPERTIES
diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp
new file mode 100644 (file)
index 0000000..0cdf9c3
--- /dev/null
@@ -0,0 +1,333 @@
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+// utility to get one slice from the third dimension
+// input dim:  [x, y, c, b]
+// output dim: [x, y, 1, b]
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
+}
+
+llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+
+    const float scale = 1.0f / sqrtf(S_k);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(b, "b_in", il);
+    cb(g, "g_in", il);
+
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+    g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [  1, n_tokens, H_v, n_seqs]
+    b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [  1, n_tokens, H_v, n_seqs]
+
+    const int CS = CHUNK_SIZE;
+
+    const int pad = (CS - n_tokens % CS) % CS;
+    const int n_chunks = (n_tokens + pad) / CS;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
+    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
+
+    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
+    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
+
+    cb(v_b, "v_b", il);
+    cb(k_b, "k_b", il);
+
+    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
+    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
+    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
+    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
+    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
+
+    g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
+
+    // [CS, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
+    cb(g_cs, "g_cs", il);
+
+    ggml_tensor * g_cs_i = g_cs;
+    ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
+
+    g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
+
+    // [CS, CS, n_chunks, H_v * n_seqs]
+    ggml_tensor * decay_mask;
+    decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+    decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    cb(decay_mask, "decay_mask", il);
+
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * kb;
+    kb = ggml_mul_mat(ctx0, k,  k_b);
+    kb = ggml_mul    (ctx0, kb, decay_mask);
+
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * attn;
+    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * identity;
+    identity = ggml_view_1d(ctx0, attn, CS, 0);
+    identity = ggml_fill   (ctx0, identity, 1.0f);
+    identity = ggml_diag   (ctx0, identity);
+
+    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
+    cb(lhs, "dnet_add_ch_lhs", il);
+
+    attn = ggml_neg(ctx0, attn);
+
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn = ggml_add(ctx0, lin_solve, identity);
+    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
+
+    // [S_v, CS, n_chunks, H_v * n_seqs]
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
+
+    // [CS, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
+
+    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
+
+    // [CS, S_k, n_chunks, H_k * n_seqs]
+    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
+    cb(kbg, "k_beta_g_exp", il);
+
+    // [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
+    cb(k_cd, "k_cumdecay", il);
+
+    // [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
+    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
+
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+    kq = ggml_mul(ctx0, kq, decay_mask);
+    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
+    cb(kq, "kq", il);
+
+    // vectorized calculation of key_gdiff
+    // improved from the chunked version:
+    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+    //   key_gdiff = key * g_diff.unsqueeze(-1)
+    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+
+    // get last element in g_cumsum along CS dimension (ne0)
+    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
+    // [1, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
+            g_cs->nb[1],
+            g_cs->nb[2],
+            g_cs->nb[3],
+            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
+    cb(g_last, "g_last", il);
+
+    // TODO: remove this cont when CUDA supports non-cont unary ops
+    g_last = ggml_cont(ctx0, g_last);
+
+    // [1, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
+    cb(g_last_exp, "g_last_exp", il);
+
+    // [CS, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
+    cb(g_diff, "g_diff", il);
+
+    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
+    ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
+
+    // [S_k, CS, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
+    cb(kg, "key_gdiff", il);
+
+    // [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
+    cb(kg_t, "key_gdiff_t", il);
+
+    ggml_tensor * s_t = ggml_transpose(ctx0, s);
+    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
+    cb(s_t, "dnet_add_ch_state", il);
+
+    // [CS, S_v, n_chunks, H_v * n_seqs]
+    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
+
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
+        cb(v_t_p, "v_prime", il);
+
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
+        cb(v_t_new, "v_t_new", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
+        cb(v_attn, "v_attn", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
+        cb(attn_inter, "attn_inter", il);
+
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
+        cb(o_ch, "dnet_add_ch_attn_out", il);
+
+        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
+
+        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+        // TODO: head broadcast might not work here - probably will need a transpose
+        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
+
+        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+        ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
+        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
+        s_t = ggml_add(ctx0, s_t, kgv);
+        cb(s_t, "dnet_add_ch_state", il);
+    }
+
+    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
+
+    // truncate padded tokens
+    ggml_tensor * o = ggml_view_4d(ctx0, v,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(v->type, S_v),
+            ggml_row_size(v->type, S_v * CS * n_chunks),
+            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
+
+    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
+
+    return {o, s};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b, // beta
+        ggml_tensor * s, // state
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+
+    const float scale = 1.0f / sqrtf(S_k);
+
+    q = ggml_scale(ctx0, q, scale);
+
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(b, "b_in", il);
+    cb(g, "g_in", il);
+
+    g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
+
+    // [S_v, S_v, H_v, n_seqs]
+    g = ggml_exp(ctx0, g);
+    s = ggml_mul(ctx0, s, g);
+
+    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
+
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * sk;
+    sk = ggml_mul     (ctx0, s_t, k);
+    sk = ggml_sum_rows(ctx0, sk);
+
+    // [S_v, 1, H_v, n_seqs]
+    ggml_tensor * d;
+    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
+    d = ggml_mul(ctx0, d, b);
+
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * d_t;
+    d_t = ggml_transpose(ctx0, d);
+
+    // [S_v, S_v, H_v, n_seqs]
+    ggml_tensor * kd;
+    k  = ggml_repeat(ctx0, k, s);
+    kd = ggml_mul   (ctx0, k, d_t);
+
+    s_t = ggml_add(ctx0, s_t, kd);
+
+    cb(s_t, "dnet_add_ar_state", il);
+
+    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
+    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
+
+    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
+
+    return {o, s};
+}
index b641a094079421215385a85a6a298d59874d304d..785a7e5e6629cb1190ee8d2bb5e2901550112fb7 100644 (file)
@@ -1,9 +1,7 @@
 #include "models.h"
 
-
-
 llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     ggml_tensor * cur;
index f6ca4c17a214a289386aa3c125ef411048773f76..726ecdcca776e5aa8d45796f83cbd160ef43f949 100644 (file)
@@ -2,7 +2,7 @@
 
 
 llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
diff --git a/src/models/graph-context-mamba.cpp b/src/models/graph-context-mamba.cpp
deleted file mode 100644 (file)
index b9a363b..0000000
+++ /dev/null
@@ -1,283 +0,0 @@
-#include "models.h"
-
-llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
-
-ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
-                                                         ggml_tensor *        cur,
-                                                         const llama_model &  model,
-                                                         const llama_ubatch & ubatch,
-                                                         int                  il) {
-    const auto * mctx_cur = inp->mctx;
-
-    const auto kv_head = mctx_cur->get_head();
-
-    const auto & layer = model.layers[il];
-
-    const int64_t d_conv         = hparams.ssm_d_conv;
-    const int64_t d_inner        = hparams.ssm_d_inner;
-    const int64_t d_state        = hparams.ssm_d_state;
-    const int64_t dt_rank        = hparams.ssm_dt_rank;
-    const int64_t n_head         = d_inner;
-    const int64_t head_dim       = 1;
-    const int64_t n_seqs         = ubatch.n_seqs;
-    // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
-    const bool    ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-
-    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-    GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(ubatch.equal_seqs());
-    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
-
-    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
-    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-
-    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
-    ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
-    // split the above in two
-    // => {d_inner, n_seq_tokens, n_seqs}
-    ggml_tensor * x  = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
-    ggml_tensor * z =
-        ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));
-
-    // conv
-    {
-        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
-        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
-
-        // copy last (d_conv - 1) columns back into the state cache
-        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
-                                               n_seq_tokens * (conv_x->nb[0]));
-
-        ggml_build_forward_expand(
-            gf, ggml_cpy(ctx0, last_conv,
-                         ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
-                                      kv_head * (d_conv - 1) * (d_inner) *ggml_element_size(conv_states_all))));
-
-        // 1D convolution
-        // The equivalent is to make a self-overlapping view of conv_x
-        // over d_conv columns at each stride in the 3rd dimension,
-        // then element-wise multiply that with the conv1d weight,
-        // then sum the elements of each row,
-        // (the last two steps are a dot product over rows (also doable with mul_mat))
-        // then permute away the ne[0] dimension,
-        // and then you're left with the resulting x tensor.
-        // For simultaneous sequences, all sequences need to have the same length.
-        x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);
-
-        // bias
-        x = ggml_add(ctx0, x, layer.ssm_conv1d_b);
-
-        x = ggml_silu(ctx0, x);
-    }
-
-    // ssm
-    {
-        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
-        ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
-        // split
-        ggml_tensor * dt   = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
-        ggml_tensor * B =
-            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
-                         x_db->nb[2], ggml_element_size(x_db) * dt_rank);
-        ggml_tensor * C =
-            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
-                         x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));
-
-        // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
-        if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
-            dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
-            B  = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
-            C  = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
-        }
-
-        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
-        dt = build_lora_mm(layer.ssm_dt, dt);
-        dt = ggml_add(ctx0, dt, layer.ssm_dt_b);
-
-        cur = x;
-        x   = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
-
-        ggml_tensor * A = layer.ssm_a;
-
-        // use the states and the indices provided by build_recurrent_state
-        // (this is necessary in order to properly use the states before they are overwritten,
-        //  while avoiding to make unnecessary copies of the states)
-        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
-            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
-
-            // Custom operator to optimize the parallel associative scan
-            // as described in the Annex D of the Mamba paper.
-            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
-        };
-
-        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
-        // store last states
-        ggml_build_forward_expand(
-            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
-                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
-                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
-
-        ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
-
-        // TODO: skip computing output earlier for unused tokens
-
-        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
-        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
-        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-        cur = build_lora_mm(layer.ssm_out, y);
-    }
-
-    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-
-    return cur;
-}
-
-ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
-                                                          ggml_tensor *        cur,
-                                                          const llama_model &  model,
-                                                          const llama_ubatch & ubatch,
-                                                          int                  il) const {
-    const auto * mctx_cur = inp->mctx;
-
-    const auto kv_head = mctx_cur->get_head();
-
-    const int64_t d_conv   = hparams.ssm_d_conv;
-    const int64_t d_inner  = hparams.ssm_d_inner;
-    const int64_t d_state  = hparams.ssm_d_state;
-    const int64_t n_head   = hparams.ssm_dt_rank;
-    const int64_t head_dim = d_inner / n_head;
-    const int64_t n_group  = hparams.ssm_n_group;
-    const int64_t n_seqs   = ubatch.n_seqs;
-
-    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
-    GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(ubatch.equal_seqs());
-    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
-    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
-    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
-
-    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
-    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);
-
-    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
-    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
-    // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
-    // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
-    ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
-
-    // split the above in three
-    ggml_tensor * z   = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
-                                     zxBCdt->nb[1], zxBCdt->nb[2], 0);
-    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
-                                     zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
-    ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
-                                     (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));
-
-    // conv
-    {
-        // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
-        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
-        // copy last (d_conv - 1) columns back into the state cache
-        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
-                                               conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));
-
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
-                                               ggml_view_1d(ctx0, conv_states_all,
-                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
-                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
-                                                                ggml_element_size(conv_states_all))));
-
-        // 1D convolution
-        // The equivalent is to make a self-overlapping view of conv_x
-        // over d_conv columns at each stride in the 3rd dimension,
-        // then element-wise multiply that with the conv1d weight,
-        // then sum the elements of each row,
-        // (the last two steps are a dot product over rows (also doable with mul_mat))
-        // then permute away the ne[0] dimension,
-        // and then you're left with the resulting x tensor.
-        // For simultaneous sequences, all sequences need to have the same length.
-        xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
-        // bias
-        xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
-        xBC = ggml_silu(ctx0, xBC);
-    }
-
-    // ssm
-    {
-        // These correspond to V K Q in SSM/attention duality
-        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
-                                       xBC->nb[1], xBC->nb[2], 0);
-        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
-                                       xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
-        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
-                                       xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));
-
-        // {n_head, n_seq_tokens, n_seqs}
-        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
-        ggml_tensor * A = model.layers[il].ssm_a;
-
-        // use the states and the indices provided by build_recurrent_state
-        // (this is necessary in order to properly use the states before they are overwritten,
-        //  while avoiding to make unnecessary copies of the states)
-        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
-            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
-
-            // TODO: use semistructured matrices to implement state-space duality
-            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
-            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
-        };
-
-        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
-        // store last states
-        ggml_build_forward_expand(
-            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
-                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
-                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
-
-        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
-                                       n_seq_tokens * n_head * x->nb[1], 0);
-
-        // TODO: skip computing output earlier for unused tokens
-
-        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
-        cb(y, "mamba2_y_add_d", il);
-        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
-        // grouped RMS norm
-        if (model.layers[il].ssm_norm) {
-            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
-            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
-        }
-
-        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
-        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
-        cur = build_lora_mm(model.layers[il].ssm_out, y);
-    }
-
-    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
-    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-    cb(cur, "mamba_out", il);
-
-    return cur;
-}
index a0187772ccbe8ec58b75ded62b6aa39ba3dd1900..ceab58174079be9810f67d74d76ba7f814bed8c2 100644 (file)
@@ -1,6 +1,6 @@
 #include "models.h"
 
-llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     ggml_tensor * cur;
index 942844d071f32720f3bb4f81111e520c7c0825a6..133834021d09b115e6c57dee84befb83c3a89846 100644 (file)
@@ -1,6 +1,8 @@
 #include "models.h"
 #include "ggml.h"
 
+#include "llama-memory-recurrent.h"
+
 #define CHUNK_SIZE 64
 
 // Causal Conv1d function for Q,K,V
@@ -65,7 +67,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t
 }
 
 llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_build_mamba_base(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
diff --git a/src/models/mamba-base.cpp b/src/models/mamba-base.cpp
new file mode 100644 (file)
index 0000000..aaac948
--- /dev/null
@@ -0,0 +1,285 @@
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
+                                                         ggml_tensor *        cur,
+                                                         const llama_model &  model,
+                                                         const llama_ubatch & ubatch,
+                                                         int                  il) {
+    const auto * mctx_cur = inp->mctx;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    const auto & layer = model.layers[il];
+
+    const int64_t d_conv         = hparams.ssm_d_conv;
+    const int64_t d_inner        = hparams.ssm_d_inner;
+    const int64_t d_state        = hparams.ssm_d_state;
+    const int64_t dt_rank        = hparams.ssm_dt_rank;
+    const int64_t n_head         = d_inner;
+    const int64_t head_dim       = 1;
+    const int64_t n_seqs         = ubatch.n_seqs;
+    // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
+    const bool    ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
+
+    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+    // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
+    // split the above in two
+    // => {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * x  = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+    ggml_tensor * z =
+        ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));
+
+    // conv
+    {
+        // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+
+        // copy last (d_conv - 1) columns back into the state cache
+        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
+                                               n_seq_tokens * (conv_x->nb[0]));
+
+        ggml_build_forward_expand(
+            gf, ggml_cpy(ctx0, last_conv,
+                         ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
+                                      kv_head * (d_conv - 1) * (d_inner) *ggml_element_size(conv_states_all))));
+
+        // 1D convolution
+        // The equivalent is to make a self-overlapping view of conv_x
+        // over d_conv columns at each stride in the 3rd dimension,
+        // then element-wise multiply that with the conv1d weight,
+        // then sum the elements of each row,
+        // (the last two steps are a dot product over rows (also doable with mul_mat))
+        // then permute away the ne[0] dimension,
+        // and then you're left with the resulting x tensor.
+        // For simultaneous sequences, all sequences need to have the same length.
+        x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);
+
+        // bias
+        x = ggml_add(ctx0, x, layer.ssm_conv1d_b);
+
+        x = ggml_silu(ctx0, x);
+    }
+
+    // ssm
+    {
+        // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+        ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
+        // split
+        ggml_tensor * dt   = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+        ggml_tensor * B =
+            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
+                         x_db->nb[2], ggml_element_size(x_db) * dt_rank);
+        ggml_tensor * C =
+            ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
+                         x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));
+
+        // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
+        if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
+            dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
+            B  = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
+            C  = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
+        }
+
+        // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+        dt = build_lora_mm(layer.ssm_dt, dt);
+        dt = ggml_add(ctx0, dt, layer.ssm_dt_b);
+
+        cur = x;
+        x   = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
+
+        ggml_tensor * A = layer.ssm_a;
+
+        // use the states and the indices provided by build_recurrent_state
+        // (this is necessary in order to properly use the states before they are overwritten,
+        //  while avoiding to make unnecessary copies of the states)
+        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+            // Custom operator to optimize the parallel associative scan
+            // as described in the Annex D of the Mamba paper.
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        };
+
+        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+
+        // store last states
+        ggml_build_forward_expand(
+            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
+                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
+                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
+
+        ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
+
+        // TODO: skip computing output earlier for unused tokens
+
+        y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
+        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+
+        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+        cur = build_lora_mm(layer.ssm_out, y);
+    }
+
+    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+
+    return cur;
+}
+
+ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
+                                                          ggml_tensor *        cur,
+                                                          const llama_model &  model,
+                                                          const llama_ubatch & ubatch,
+                                                          int                  il) const {
+    const auto * mctx_cur = inp->mctx;
+
+    const auto kv_head = mctx_cur->get_head();
+
+    const int64_t d_conv   = hparams.ssm_d_conv;
+    const int64_t d_inner  = hparams.ssm_d_inner;
+    const int64_t d_state  = hparams.ssm_d_state;
+    const int64_t n_head   = hparams.ssm_dt_rank;
+    const int64_t head_dim = d_inner / n_head;
+    const int64_t n_group  = hparams.ssm_n_group;
+    const int64_t n_seqs   = ubatch.n_seqs;
+
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+    ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
+
+    ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+    conv               = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);
+
+    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+    // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
+
+    // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
+    ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
+
+    // split the above in three
+    ggml_tensor * z   = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
+                                     zxBCdt->nb[1], zxBCdt->nb[2], 0);
+    ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
+                                     zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
+    ggml_tensor * dt  = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
+                                     (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));
+
+    // conv
+    {
+        // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
+        ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
+
+        // copy last (d_conv - 1) columns back into the state cache
+        ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
+                                               conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));
+
+        ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
+                                               ggml_view_1d(ctx0, conv_states_all,
+                                                            (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
+                                                            kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
+                                                                ggml_element_size(conv_states_all))));
+
+        // 1D convolution
+        // The equivalent is to make a self-overlapping view of conv_x
+        // over d_conv columns at each stride in the 3rd dimension,
+        // then element-wise multiply that with the conv1d weight,
+        // then sum the elements of each row,
+        // (the last two steps are a dot product over rows (also doable with mul_mat))
+        // then permute away the ne[0] dimension,
+        // and then you're left with the resulting x tensor.
+        // For simultaneous sequences, all sequences need to have the same length.
+        xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+
+        // bias
+        xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
+
+        xBC = ggml_silu(ctx0, xBC);
+    }
+
+    // ssm
+    {
+        // These correspond to V K Q in SSM/attention duality
+        ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
+                                       xBC->nb[1], xBC->nb[2], 0);
+        ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
+                                       xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
+        ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
+                                       xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));
+
+        // {n_head, n_seq_tokens, n_seqs}
+        dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
+
+        ggml_tensor * A = model.layers[il].ssm_a;
+
+        // use the states and the indices provided by build_recurrent_state
+        // (this is necessary in order to properly use the states before they are overwritten,
+        //  while avoiding to make unnecessary copies of the states)
+        auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+            ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+            // TODO: use semistructured matrices to implement state-space duality
+            // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+            return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+        };
+
+        ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+
+        // store last states
+        ggml_build_forward_expand(
+            gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
+                         ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
+                                      kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
+
+        ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
+                                       n_seq_tokens * n_head * x->nb[1], 0);
+
+        // TODO: skip computing output earlier for unused tokens
+
+        y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+        cb(y, "mamba2_y_add_d", il);
+        y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+
+        // grouped RMS norm
+        if (model.layers[il].ssm_norm) {
+            y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+            y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        }
+
+        y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
+
+        // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+        cur = build_lora_mm(model.layers[il].ssm_out, y);
+    }
+
+    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+    cb(cur, "mamba_out", il);
+
+    return cur;
+}
index 46819613c2d999e3b4fac10efdee82dd00f8da3f..55fd2e055c494aa0a5a02d35ea2e430fa2602c04 100644 (file)
@@ -1,7 +1,6 @@
 #include "models.h"
 
-
-llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
index ec6f80e5265949753e089ac0a1d7d2bdb20df14d..920a8e5798f395ee9ceb615053a285c704ce88b5 100644 (file)
@@ -1,23 +1,51 @@
 #pragma once
 
-#include "../llama-model.h"
-#include "../llama-graph.h"
+#include "llama-model.h"
+#include "llama-graph.h"
 
-// TODO: remove in follow-up PR - move to .cpp files
-#include "../llama-memory-recurrent.h"
+// note: almost all graphs require atleast sqrtf, so include cmath globally
 #include <cmath>
 
-struct llm_graph_context_mamba : public llm_graph_context {
-    llm_graph_context_mamba(const llm_graph_params & params);
+//
+// base classes
+//
 
-    virtual ~llm_graph_context_mamba() = default;
+struct llm_build_mamba_base : public llm_graph_context {
+    llm_build_mamba_base(const llm_graph_params & params);
+
+    virtual ~llm_build_mamba_base() = default;
 
     ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
     ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
 
 };
 
-// Base class for RWKV-related models
+struct llm_build_delta_net_base : public llm_graph_context {
+    llm_build_delta_net_base(const llm_graph_params & params);
+
+    virtual ~llm_build_delta_net_base() = default;
+
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
+
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                int           il);
+};
+
 struct llm_build_rwkv6_base : public llm_graph_context {
     const llama_model & model;
 
@@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                                        int                  il) const;
 };
 
+//
+// models
+//
+
 struct llm_build_afmoe : public llm_graph_context {
     llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
 };
@@ -175,7 +207,7 @@ struct llm_build_falcon : public llm_graph_context {
     llm_build_falcon(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+struct llm_build_falcon_h1 : public llm_build_mamba_base {
     llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
 };
 
@@ -253,7 +285,7 @@ private:
         const int                 il);
 };
 
-struct llm_build_granite_hybrid : public llm_graph_context_mamba {
+struct llm_build_granite_hybrid : public llm_build_mamba_base {
     llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
     ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
     ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
@@ -284,11 +316,12 @@ struct llm_build_jais : public llm_graph_context {
     llm_build_jais(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_jamba : public llm_graph_context_mamba {
+struct llm_build_jamba : public llm_build_mamba_base {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_kimi_linear : public llm_graph_context_mamba {
+// TODO: derive llm_build_delta_net_base instead
+struct llm_build_kimi_linear : public llm_build_mamba_base {
     llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
 
     std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
@@ -347,7 +380,7 @@ struct llm_build_maincoder : public llm_graph_context {
     llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_mamba : public llm_graph_context_mamba {
+struct llm_build_mamba : public llm_build_mamba_base {
     llm_build_mamba(const llama_model & model, const llm_graph_params & params);
 };
 
@@ -379,11 +412,11 @@ struct llm_build_nemotron : public llm_graph_context {
     llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_nemotron_h : public llm_graph_context_mamba {
+struct llm_build_nemotron_h : public llm_build_mamba_base {
     llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
-    ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
+    ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il);
     ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
-        const llama_model & model, const int64_t n_embd_head, const int il);
+        const llama_model & model, int64_t n_embd_head, int il);
 };
 
 struct llm_build_neo_bert : public llm_graph_context {
@@ -428,7 +461,7 @@ struct llm_build_phi3 : public llm_graph_context {
     llm_build_phi3(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_plamo2 : public llm_graph_context_mamba {
+struct llm_build_plamo2 : public llm_build_mamba_base {
     llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
     private:
         ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
@@ -477,7 +510,7 @@ struct llm_build_qwen3vlmoe : public llm_graph_context {
     llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
 };
 
-struct llm_build_qwen3next : public llm_graph_context_mamba {
+struct llm_build_qwen3next : public llm_build_delta_net_base {
     llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
@@ -495,26 +528,6 @@ private:
                 ggml_tensor * cur,
                         int   il);
 
-    // returns pair of output and new state
-    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                        int   il);
-
-    // returns pair of output and new state
-    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
-                ggml_tensor * q,
-                ggml_tensor * k,
-                ggml_tensor * v,
-                ggml_tensor * g,
-                ggml_tensor * beta,
-                ggml_tensor * state,
-                int           il);
-
     ggml_tensor * build_norm_gated(
                 ggml_tensor * input,
                 ggml_tensor * weights,
@@ -529,7 +542,8 @@ private:
     const llama_model & model;
 };
 
-struct llm_build_qwen35 : public llm_graph_context_mamba {
+// TODO: derive llm_build_delta_net_base instead
+struct llm_build_qwen35 : public llm_graph_context {
     llm_build_qwen35(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
@@ -547,6 +561,7 @@ private:
                 ggml_tensor * diag_mask,
                         int   il);
 
+
     ggml_tensor * build_layer_ffn(
                 ggml_tensor * cur,
                         int   il);
@@ -588,7 +603,8 @@ private:
     const llama_model & model;
 };
 
-struct llm_build_qwen35moe : public llm_graph_context_mamba {
+// TODO: derive llm_build_delta_net_base instead
+struct llm_build_qwen35moe : public llm_graph_context {
     llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params);
 private:
     ggml_tensor * build_layer_attn(
index 079c730ac292ed9596c79d6aec38ab57d54df45b..d61d62a8c962581bea21c7c345eb26830a631d7c 100644 (file)
@@ -1,9 +1,7 @@
 #include "models.h"
 
-
-
 llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_
 ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *             cur,
                                                           llm_graph_input_attn_kv * inp_attn,
                                                           const llama_model &       model,
-                                                          const int64_t             n_embd_head,
-                                                          const int                 il) {
+                                                                int64_t             n_embd_head,
+                                                                int                 il) {
     // compute Q and K
     ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur, "Qcur", il);
@@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor *
     return cur;
 }
 
-ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
+ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) {
     if (model.layers[il].ffn_gate_inp == nullptr) {
         cur = build_ffn(cur,
                 model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
index 31115a08f95e446e71ff8e704f0eea1561ee271f..3af236843bb30d2aa2a4861d0946b5ddd35386d9 100644 (file)
@@ -1,7 +1,9 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params) {
+    llm_build_mamba_base(params) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
index 592c170457bf8250b6374b9012378794047c1c19..94c68dbb268ab783c516046631afe4323d539a7d 100644 (file)
@@ -1,10 +1,11 @@
-#include "ggml.h"
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 #define CHUNK_SIZE 64
 
 llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_graph_context(params), model(model) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
index 0db8f825c67bd920f74e49e935e68fc430c7f743..93da7ea628cdf20b5f7349574b67e146037fe9cd 100644 (file)
@@ -1,10 +1,11 @@
-#include "ggml.h"
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 #define CHUNK_SIZE 64
 
 llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_graph_context(params), model(model) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
index aea8b29513e737da6a265474898171674f3142b0..0fdf2d42c252f8cda4e31f50dfa85fd230aedb49 100644 (file)
@@ -1,10 +1,9 @@
-#include "ggml.h"
 #include "models.h"
 
-#define CHUNK_SIZE 64
+#include "llama-memory-recurrent.h"
 
 llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
-    llm_graph_context_mamba(params), model(model) {
+    llm_build_delta_net_base(params), model(model) {
     ggml_tensor * cur;
     ggml_tensor * inpL;
 
@@ -83,326 +82,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t
         t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
 }
 
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * b,
-        ggml_tensor * s,
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(S_k == S_v);
-    GGML_ASSERT(H_v % H_k == 0);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
-
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
-    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
-
-    const float scale = 1.0f / sqrtf(S_k);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(b, "b_in", il);
-    cb(g, "g_in", il);
-
-    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
-    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
-    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
-    g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [  1, n_tokens, H_v, n_seqs]
-    b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [  1, n_tokens, H_v, n_seqs]
-
-    const int CS = CHUNK_SIZE;
-
-    const int pad = (CS - n_tokens % CS) % CS;
-    const int n_chunks = (n_tokens + pad) / CS;
-
-    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
-    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
-    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
-    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
-
-    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
-    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
-
-    cb(v_b, "v_b", il);
-    cb(k_b, "k_b", il);
-
-    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
-    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
-    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
-    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
-    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
-
-    g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
-    b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
-
-    // [CS, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
-    cb(g_cs, "g_cs", il);
-
-    ggml_tensor * g_cs_i = g_cs;
-    ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
-
-    g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
-
-    // [CS, CS, n_chunks, H_v * n_seqs]
-    ggml_tensor * decay_mask;
-    decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
-    decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    cb(decay_mask, "decay_mask", il);
-
-    // [CS, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * kb;
-    kb = ggml_mul_mat(ctx0, k,  k_b);
-    kb = ggml_mul    (ctx0, kb, decay_mask);
-
-    // [CS, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * attn;
-    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
-
-    ggml_tensor * identity;
-    identity = ggml_view_1d(ctx0, attn, CS, 0);
-    identity = ggml_fill   (ctx0, identity, 1.0f);
-    identity = ggml_diag   (ctx0, identity);
-
-    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
-    cb(lhs, "dnet_add_ch_lhs", il);
-
-    attn = ggml_neg(ctx0, attn);
-
-    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn = ggml_add(ctx0, lin_solve, identity);
-    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
-
-    // [S_v, CS, n_chunks, H_v * n_seqs]
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
-
-    // [CS, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
-
-    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
-
-    // [CS, S_k, n_chunks, H_k * n_seqs]
-    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
-    cb(kbg, "k_beta_g_exp", il);
-
-    // [S_k, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
-    cb(k_cd, "k_cumdecay", il);
-
-    // [S_k, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
-    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
-
-    // [CS, CS, n_chunks, H_k * n_seqs]
-    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-    kq = ggml_mul(ctx0, kq, decay_mask);
-    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
-    cb(kq, "kq", il);
-
-    // vectorized calculation of key_gdiff
-    // improved from the chunked version:
-    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-    //   key_gdiff = key * g_diff.unsqueeze(-1)
-    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-    // get last element in g_cumsum along CS dimension (ne0)
-    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    // [1, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
-            g_cs->nb[1],
-            g_cs->nb[2],
-            g_cs->nb[3],
-            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
-    cb(g_last, "g_last", il);
-
-    // TODO: remove this cont when CUDA supports non-cont unary ops
-    g_last = ggml_cont(ctx0, g_last);
-
-    // [1, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il);
-
-    // [CS, 1, n_chunks, H_v * n_seqs]
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
-    cb(g_diff, "g_diff", il);
-
-    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
-    ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
-
-    // [S_k, CS, n_chunks, H_v * n_seqs]
-    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
-    cb(kg, "key_gdiff", il);
-
-    // [CS, S_k, n_chunks, H_v * n_seqs]
-    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
-    cb(kg_t, "key_gdiff_t", il);
-
-    ggml_tensor * s_t = ggml_transpose(ctx0, s);
-    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
-    cb(s_t, "dnet_add_ch_state", il);
-
-    // [CS, S_v, n_chunks, H_v * n_seqs]
-    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
-
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
-        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
-        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
-        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
-        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
-
-        // [CS, S_v, 1, H_v * n_seqs]
-        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
-        cb(v_t_p, "v_prime", il);
-
-        // [CS, S_v, 1, H_v * n_seqs]
-        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
-        cb(v_t_new, "v_t_new", il);
-
-        // [S_v, CS, 1, H_v * n_seqs]
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
-        cb(v_attn, "v_attn", il);
-
-        // [S_v, CS, 1, H_v * n_seqs]
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
-        cb(attn_inter, "attn_inter", il);
-
-        // [S_v, CS, 1, H_v * n_seqs]
-        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
-        cb(o_ch, "dnet_add_ch_attn_out", il);
-
-        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
-
-        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        // TODO: head broadcast might not work here - probably will need a transpose
-        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
-
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
-        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
-        s_t = ggml_add(ctx0, s_t, kgv);
-        cb(s_t, "dnet_add_ch_state", il);
-    }
-
-    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
-
-    // truncate padded tokens
-    ggml_tensor * o = ggml_view_4d(ctx0, v,
-            S_v, n_tokens, H_v, n_seqs,
-            ggml_row_size(v->type, S_v),
-            ggml_row_size(v->type, S_v * CS * n_chunks),
-            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
-
-    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
-    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
-
-    return {o, s};
-}
-
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
-        ggml_tensor * q,
-        ggml_tensor * k,
-        ggml_tensor * v,
-        ggml_tensor * g,
-        ggml_tensor * b, // beta
-        ggml_tensor * s, // state
-        int           il) {
-    const int64_t S_k      = q->ne[0];
-    const int64_t H_k      = q->ne[1];
-    const int64_t n_tokens = q->ne[2];
-    const int64_t n_seqs   = q->ne[3];
-
-    const int64_t S_v = v->ne[0];
-    const int64_t H_v = v->ne[1];
-
-    GGML_ASSERT(n_tokens == 1);
-
-    GGML_ASSERT(S_k == S_v);
-    GGML_ASSERT(H_v % H_k == 0);
-
-    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
-
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
-    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
-
-    const float scale = 1.0f / sqrtf(S_k);
-
-    q = ggml_scale(ctx0, q, scale);
-
-    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
-    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
-    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
-
-    cb(q, "q_in", il);
-    cb(k, "k_in", il);
-    cb(v, "v_in", il);
-    cb(b, "b_in", il);
-    cb(g, "g_in", il);
-
-    g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
-    b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
-
-    // [S_v, S_v, H_v, n_seqs]
-    g = ggml_exp(ctx0, g);
-    s = ggml_mul(ctx0, s, g);
-
-    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
-
-    // [1, S_v, H_v, n_seqs]
-    ggml_tensor * sk;
-    sk = ggml_mul     (ctx0, s_t, k);
-    sk = ggml_sum_rows(ctx0, sk);
-
-    // [S_v, 1, H_v, n_seqs]
-    ggml_tensor * d;
-    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
-    d = ggml_mul(ctx0, d, b);
-
-    // [1, S_v, H_v, n_seqs]
-    ggml_tensor * d_t;
-    d_t = ggml_transpose(ctx0, d);
-
-    // [S_v, S_v, H_v, n_seqs]
-    ggml_tensor * kd;
-    k  = ggml_repeat(ctx0, k, s);
-    kd = ggml_mul   (ctx0, k, d_t);
-
-    s_t = ggml_add(ctx0, s_t, kd);
-
-    cb(s_t, "dnet_add_ar_state", il);
-
-    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
-    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
-
-    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
-    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
-
-    return {o, s};
-}
-
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
         ggml_tensor * input,
         ggml_tensor * weights,
index 7beed2daffbdd84c5e604a5e841cc4f9d97b65df..83aeab7280b1cd5cdc3b13dca9723e67b1c9f8bd 100644 (file)
@@ -1,5 +1,7 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model) {}
index cda44653849b88c31b7e240543b64b0423fbd181..7fcab77745c75a887824ac6266e50536eb845ea2 100644 (file)
@@ -1,5 +1,7 @@
 #include "models.h"
 
+#include "llama-memory-recurrent.h"
+
 llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params),
     model(model) {}