models/deci.cpp
models/deepseek.cpp
models/deepseek2.cpp
+ models/delta-net-base.cpp
models/dots1.cpp
models/dream.cpp
models/ernie4-5-moe.cpp
models/ernie4-5.cpp
+ models/exaone-moe.cpp
models/exaone.cpp
models/exaone4.cpp
- models/exaone-moe.cpp
models/falcon-h1.cpp
models/falcon.cpp
models/gemma-embedding.cpp
models/llama-iswa.cpp
models/llama.cpp
models/maincoder.cpp
+ models/mamba-base.cpp
models/mamba.cpp
models/mimo2-iswa.cpp
models/minicpm3.cpp
models/minimax-m2.cpp
+ models/mistral3.cpp
models/modern-bert.cpp
models/mpt.cpp
models/nemotron-h.cpp
models/qwen2moe.cpp
models/qwen2vl.cpp
models/qwen3.cpp
- models/qwen3vl.cpp
- models/qwen3vl-moe.cpp
- models/qwen3moe.cpp
- models/qwen3next.cpp
models/qwen35.cpp
models/qwen35moe.cpp
+ models/qwen3moe.cpp
+ models/qwen3next.cpp
+ models/qwen3vl-moe.cpp
+ models/qwen3vl.cpp
models/refact.cpp
models/rnd1.cpp
models/rwkv6-base.cpp
models/t5-enc.cpp
models/wavtokenizer-dec.cpp
models/xverse.cpp
- models/mistral3.cpp
- models/graph-context-mamba.cpp
)
set_target_properties(llama PROPERTIES
--- /dev/null
+#include "models.h"
+
+#define CHUNK_SIZE 64
+
+// utility to get one slice from the third dimension
+// input dim: [x, y, c, b]
+// output dim: [x, y, 1, b]
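+// (used by the chunked delta net below to step through chunks while keeping
+// the flattened head*seq dimension intact)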
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+ return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+ t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
+}
+
+llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking(
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ ggml_tensor * s,
+ int il) {
+ const int64_t S_k = q->ne[0];
+ const int64_t H_k = q->ne[1];
+ const int64_t n_tokens = q->ne[2];
+ const int64_t n_seqs = q->ne[3];
+
+ const int64_t S_v = v->ne[0];
+ const int64_t H_v = v->ne[1];
+
+ GGML_ASSERT(S_k == S_v);
+ GGML_ASSERT(H_v % H_k == 0);
+
+ GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+ GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+ GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+ GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+ GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+ GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+
+ const float scale = 1.0f / sqrtf(S_k);
+
+ q = ggml_scale(ctx0, q, scale);
+
+ cb(q, "q_in", il);
+ cb(k, "k_in", il);
+ cb(v, "v_in", il);
+ cb(b, "b_in", il);
+ cb(g, "g_in", il);
+
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+ k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+ v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+ g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs]
+ b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs]
+
+ const int CS = CHUNK_SIZE;
+
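+    // pad the token dimension up to a whole number of chunks,
+    // e.g. n_tokens = 100, CS = 64 -> pad = 28, n_chunks = 2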
+ const int pad = (CS - n_tokens % CS) % CS;
+ const int n_chunks = (n_tokens + pad) / CS;
+
+ q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+ k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+ v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+ g = ggml_pad(ctx0, g, 0, pad, 0, 0);
+ b = ggml_pad(ctx0, b, 0, pad, 0, 0);
+
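+    // beta-scaled value and key (the per-token write strength of the delta rule)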
+ ggml_tensor * v_b = ggml_mul(ctx0, v, b);
+ ggml_tensor * k_b = ggml_mul(ctx0, k, b);
+
+ cb(v_b, "v_b", il);
+ cb(k_b, "k_b", il);
+
+ q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs);
+ k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs);
+ k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
+ v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs);
+ v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
+
+ g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
+ b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
+
+ // [CS, 1, n_chunks, H_v * n_seqs]
+ ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
+ cb(g_cs, "g_cs", il);
+
+ ggml_tensor * g_cs_i = g_cs;
+ ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
+
+ g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
+
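+    // cumulative decay between chunk positions i <= j: exp(g_cs[j] - g_cs[i]),
+    // masked to the causal (lower) triangle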
+ // [CS, CS, n_chunks, H_v * n_seqs]
+ ggml_tensor * decay_mask;
+ decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+ decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+ decay_mask = ggml_exp(ctx0, decay_mask);
+ cb(decay_mask, "decay_mask", il);
+
+ // [CS, CS, n_chunks, H_k * n_seqs]
+ ggml_tensor * kb;
+ kb = ggml_mul_mat(ctx0, k, k_b);
+ kb = ggml_mul (ctx0, kb, decay_mask);
+
+ // [CS, CS, n_chunks, H_k * n_seqs]
+ ggml_tensor * attn;
+ attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
+
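+    // build a CS x CS identity: take a CS-element view of attn (for shape/type only),
+    // fill it with ones, then diagonalize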
+ ggml_tensor * identity;
+ identity = ggml_view_1d(ctx0, attn, CS, 0);
+ identity = ggml_fill (ctx0, identity, 1.0f);
+ identity = ggml_diag (ctx0, identity);
+
+ ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
+ cb(lhs, "dnet_add_ch_lhs", il);
+
+ attn = ggml_neg(ctx0, attn);
+
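+    // forward substitution: solve (I + A) X = -A over the lower-triangular system,
+    // so that X + I = (I + A)^{-1} without forming an explicit inverse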
+ ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+ attn = ggml_add(ctx0, lin_solve, identity);
+ cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
+
+ // [S_v, CS, n_chunks, H_v * n_seqs]
+ v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
+
+ // [CS, 1, n_chunks, H_v * n_seqs]
+ ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
+
+ k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
+
+ // [CS, S_k, n_chunks, H_k * n_seqs]
+ ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
+ cb(kbg, "k_beta_g_exp", il);
+
+ // [S_k, CS, n_chunks, H_k * n_seqs]
+ ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
+ cb(k_cd, "k_cumdecay", il);
+
+ // [S_k, CS, n_chunks, H_k * n_seqs]
+ ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
+ ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
+
+ // [CS, CS, n_chunks, H_k * n_seqs]
+ ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+ kq = ggml_mul(ctx0, kq, decay_mask);
+ kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
+ cb(kq, "kq", il);
+
+ // vectorized calculation of key_gdiff
+ // improved from the chunked version:
+ // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+ // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+ // key_gdiff = key * g_diff.unsqueeze(-1)
+ // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+ // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+
+ // get last element in g_cumsum along CS dimension (ne0)
+ // example: [[x, y, z, ..., last], ...] -> [[last], ...]
+ // [1, 1, n_chunks, H_v * n_seqs]
+ ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
+ g_cs->nb[1],
+ g_cs->nb[2],
+ g_cs->nb[3],
+ ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
+ cb(g_last, "g_last", il);
+
+ // TODO: remove this cont when CUDA supports non-cont unary ops
+ g_last = ggml_cont(ctx0, g_last);
+
+ // [1, 1, n_chunks, H_v * n_seqs]
+ ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
+ cb(g_last_exp, "g_last_exp", il);
+
+ // [CS, 1, n_chunks, H_v * n_seqs]
+ ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
+ cb(g_diff, "g_diff", il);
+
+ ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
+ ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
+
+ // [S_k, CS, n_chunks, H_v * n_seqs]
+ ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
+ cb(kg, "key_gdiff", il);
+
+ // [CS, S_k, n_chunks, H_v * n_seqs]
+ ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
+ cb(kg_t, "key_gdiff_t", il);
+
+ ggml_tensor * s_t = ggml_transpose(ctx0, s);
+ s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
+ cb(s_t, "dnet_add_ch_state", il);
+
+ // [CS, S_v, n_chunks, H_v * n_seqs]
+ ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+
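+    // sequential inter-chunk recurrence:
+    //   v_prime = k_cumdecay @ state    (the part of v already explained by the state)
+    //   v_new   = v - v_prime
+    //   output  = state @ (q * g_exp) + intra-chunk attention (kq) over v_new
+    //   state   = state * exp(g_last) + key_gdiff^T @ v_new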
+ for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+ ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs]
+ ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs]
+ ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs]
+ ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs]
+ ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]
+
+ // [CS, S_v, 1, H_v * n_seqs]
+ ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
+ cb(v_t_p, "v_prime", il);
+
+ // [CS, S_v, 1, H_v * n_seqs]
+ ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
+ cb(v_t_new, "v_t_new", il);
+
+ // [S_v, CS, 1, H_v * n_seqs]
+ ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
+ cb(v_attn, "v_attn", il);
+
+ // [S_v, CS, 1, H_v * n_seqs]
+ ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
+ cb(attn_inter, "attn_inter", il);
+
+ // [S_v, CS, 1, H_v * n_seqs]
+ ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
+ cb(o_ch, "dnet_add_ch_attn_out", il);
+
+ v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
+
+ // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+ // TODO: head broadcast might not work here - probably will need a transpose
+ ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
+
+ // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+ ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
+ s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
+ s_t = ggml_add(ctx0, s_t, kgv);
+ cb(s_t, "dnet_add_ch_state", il);
+ }
+
+ s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
+
+ // truncate padded tokens
+ ggml_tensor * o = ggml_view_4d(ctx0, v,
+ S_v, n_tokens, H_v, n_seqs,
+ ggml_row_size(v->type, S_v),
+ ggml_row_size(v->type, S_v * CS * n_chunks),
+ ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
+
+ o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+ s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
+
+ return {o, s};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive(
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b, // beta
+ ggml_tensor * s, // state
+ int il) {
+ const int64_t S_k = q->ne[0];
+ const int64_t H_k = q->ne[1];
+ const int64_t n_tokens = q->ne[2];
+ const int64_t n_seqs = q->ne[3];
+
+ const int64_t S_v = v->ne[0];
+ const int64_t H_v = v->ne[1];
+
+ GGML_ASSERT(n_tokens == 1);
+
+ GGML_ASSERT(S_k == S_v);
+ GGML_ASSERT(H_v % H_k == 0);
+
+ GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+ GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+ GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+ GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+ GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+ GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
+
+ const float scale = 1.0f / sqrtf(S_k);
+
+ q = ggml_scale(ctx0, q, scale);
+
+ q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+ k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+ v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+
+ cb(q, "q_in", il);
+ cb(k, "k_in", il);
+ cb(v, "v_in", il);
+ cb(b, "b_in", il);
+ cb(g, "g_in", il);
+
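+    // single-token delta rule: decay the state by exp(g), form the beta-scaled
+    // correction d = b * (v - state @ k), apply it as a rank-1 update
+    // state += k @ d^T, then read the output out with q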
+ g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
+ b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
+
+ // [S_v, S_v, H_v, n_seqs]
+ g = ggml_exp(ctx0, g);
+ s = ggml_mul(ctx0, s, g);
+
+ ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
+
+ // [1, S_v, H_v, n_seqs]
+ ggml_tensor * sk;
+ sk = ggml_mul (ctx0, s_t, k);
+ sk = ggml_sum_rows(ctx0, sk);
+
+ // [S_v, 1, H_v, n_seqs]
+ ggml_tensor * d;
+ d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
+ d = ggml_mul(ctx0, d, b);
+
+ // [1, S_v, H_v, n_seqs]
+ ggml_tensor * d_t;
+ d_t = ggml_transpose(ctx0, d);
+
+ // [S_v, S_v, H_v, n_seqs]
+ ggml_tensor * kd;
+ k = ggml_repeat(ctx0, k, s);
+ kd = ggml_mul (ctx0, k, d_t);
+
+ s_t = ggml_add(ctx0, s_t, kd);
+
+ cb(s_t, "dnet_add_ar_state", il);
+
+ ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
+ ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
+
+ o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+ s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
+
+ return {o, s};
+}
#include "models.h"
-
-
llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params) {
+ llm_build_mamba_base(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
ggml_tensor * cur;
llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params) {
+ llm_build_mamba_base(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+++ /dev/null
-#include "models.h"
-
-llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {}
-
-ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp,
- ggml_tensor * cur,
- const llama_model & model,
- const llama_ubatch & ubatch,
- int il) {
- const auto * mctx_cur = inp->mctx;
-
- const auto kv_head = mctx_cur->get_head();
-
- const auto & layer = model.layers[il];
-
- const int64_t d_conv = hparams.ssm_d_conv;
- const int64_t d_inner = hparams.ssm_d_inner;
- const int64_t d_state = hparams.ssm_d_state;
- const int64_t dt_rank = hparams.ssm_dt_rank;
- const int64_t n_head = d_inner;
- const int64_t head_dim = 1;
- const int64_t n_seqs = ubatch.n_seqs;
- // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
- const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
-
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
- GGML_ASSERT(n_seqs != 0);
- GGML_ASSERT(ubatch.equal_seqs());
- GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
- ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
- ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
-
- ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
- conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
-
- // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
- cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
- // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
- ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
- // split the above in two
- // => {d_inner, n_seq_tokens, n_seqs}
- ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
- ggml_tensor * z =
- ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));
-
- // conv
- {
- // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
- ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
-
- // copy last (d_conv - 1) columns back into the state cache
- ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
- n_seq_tokens * (conv_x->nb[0]));
-
- ggml_build_forward_expand(
- gf, ggml_cpy(ctx0, last_conv,
- ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
- kv_head * (d_conv - 1) * (d_inner) *ggml_element_size(conv_states_all))));
-
- // 1D convolution
- // The equivalent is to make a self-overlapping view of conv_x
- // over d_conv columns at each stride in the 3rd dimension,
- // then element-wise multiply that with the conv1d weight,
- // then sum the elements of each row,
- // (the last two steps are a dot product over rows (also doable with mul_mat))
- // then permute away the ne[0] dimension,
- // and then you're left with the resulting x tensor.
- // For simultaneous sequences, all sequences need to have the same length.
- x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);
-
- // bias
- x = ggml_add(ctx0, x, layer.ssm_conv1d_b);
-
- x = ggml_silu(ctx0, x);
- }
-
- // ssm
- {
- // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
- ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
- // split
- ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
- ggml_tensor * B =
- ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
- x_db->nb[2], ggml_element_size(x_db) * dt_rank);
- ggml_tensor * C =
- ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
- x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));
-
- // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
- if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
- dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
- B = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
- C = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
- }
-
- // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
- dt = build_lora_mm(layer.ssm_dt, dt);
- dt = ggml_add(ctx0, dt, layer.ssm_dt_b);
-
- cur = x;
- x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
-
- ggml_tensor * A = layer.ssm_a;
-
- // use the states and the indices provided by build_recurrent_state
- // (this is necessary in order to properly use the states before they are overwritten,
- // while avoiding to make unnecessary copies of the states)
- auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
- ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
-
- // Custom operator to optimize the parallel associative scan
- // as described in the Annex D of the Mamba paper.
- // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
- return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
- };
-
- ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
- // store last states
- ggml_build_forward_expand(
- gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
- ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
- kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
-
- ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
-
- // TODO: skip computing output earlier for unused tokens
-
- y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
- y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
- // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
- cur = build_lora_mm(layer.ssm_out, y);
- }
-
- // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
-
- return cur;
-}
-
-ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp,
- ggml_tensor * cur,
- const llama_model & model,
- const llama_ubatch & ubatch,
- int il) const {
- const auto * mctx_cur = inp->mctx;
-
- const auto kv_head = mctx_cur->get_head();
-
- const int64_t d_conv = hparams.ssm_d_conv;
- const int64_t d_inner = hparams.ssm_d_inner;
- const int64_t d_state = hparams.ssm_d_state;
- const int64_t n_head = hparams.ssm_dt_rank;
- const int64_t head_dim = d_inner / n_head;
- const int64_t n_group = hparams.ssm_n_group;
- const int64_t n_seqs = ubatch.n_seqs;
-
- const int64_t n_seq_tokens = ubatch.n_seq_tokens;
-
- GGML_ASSERT(n_seqs != 0);
- GGML_ASSERT(ubatch.equal_seqs());
- GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
-
- ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
- ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
-
- ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
- conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);
-
- // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
- cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
-
- // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
-
- // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
- ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
-
- // split the above in three
- ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
- zxBCdt->nb[1], zxBCdt->nb[2], 0);
- ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
- zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
- ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
- (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));
-
- // conv
- {
- // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
- ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
-
- // copy last (d_conv - 1) columns back into the state cache
- ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
- conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));
-
- ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
- ggml_view_1d(ctx0, conv_states_all,
- (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
- kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
- ggml_element_size(conv_states_all))));
-
- // 1D convolution
- // The equivalent is to make a self-overlapping view of conv_x
- // over d_conv columns at each stride in the 3rd dimension,
- // then element-wise multiply that with the conv1d weight,
- // then sum the elements of each row,
- // (the last two steps are a dot product over rows (also doable with mul_mat))
- // then permute away the ne[0] dimension,
- // and then you're left with the resulting x tensor.
- // For simultaneous sequences, all sequences need to have the same length.
- xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
-
- // bias
- xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
-
- xBC = ggml_silu(ctx0, xBC);
- }
-
- // ssm
- {
- // These correspond to V K Q in SSM/attention duality
- ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
- xBC->nb[1], xBC->nb[2], 0);
- ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
- xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
- ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
- xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));
-
- // {n_head, n_seq_tokens, n_seqs}
- dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
-
- ggml_tensor * A = model.layers[il].ssm_a;
-
- // use the states and the indices provided by build_recurrent_state
- // (this is necessary in order to properly use the states before they are overwritten,
- // while avoiding to make unnecessary copies of the states)
- auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
- ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
-
- // TODO: use semistructured matrices to implement state-space duality
- // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
- return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
- };
-
- ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
-
- // store last states
- ggml_build_forward_expand(
- gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
- ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
- kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
-
- ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
- n_seq_tokens * n_head * x->nb[1], 0);
-
- // TODO: skip computing output earlier for unused tokens
-
- y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
- cb(y, "mamba2_y_add_d", il);
- y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
-
- // grouped RMS norm
- if (model.layers[il].ssm_norm) {
- y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
- y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
- }
-
- y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
-
- // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
- cur = build_lora_mm(model.layers[il].ssm_out, y);
- }
-
- // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
- cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
- cb(cur, "mamba_out", il);
-
- return cur;
-}
#include "models.h"
-llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
ggml_tensor * cur;
#include "models.h"
#include "ggml.h"
+#include "llama-memory-recurrent.h"
+
#define CHUNK_SIZE 64
// Causal Conv1d function for Q,K,V
}
llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params), model(model) {
+ llm_build_mamba_base(params), model(model) {
ggml_tensor * cur;
ggml_tensor * inpL;
--- /dev/null
+#include "models.h"
+
+#include "llama-memory-recurrent.h"
+
+llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {}
+
+ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp,
+ ggml_tensor * cur,
+ const llama_model & model,
+ const llama_ubatch & ubatch,
+ int il) {
+ const auto * mctx_cur = inp->mctx;
+
+ const auto kv_head = mctx_cur->get_head();
+
+ const auto & layer = model.layers[il];
+
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t dt_rank = hparams.ssm_dt_rank;
+ const int64_t n_head = d_inner;
+ const int64_t head_dim = 1;
+ const int64_t n_seqs = ubatch.n_seqs;
+    // Some Mamba variants (e.g. FalconMamba) apply layer norm on B and Dt layers
+ const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
+
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+ GGML_ASSERT(n_seqs != 0);
+ GGML_ASSERT(ubatch.equal_seqs());
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+ ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+ conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
+
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+ // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs}
+ ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur);
+ // split the above in two
+ // => {d_inner, n_seq_tokens, n_seqs}
+ ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0);
+ ggml_tensor * z =
+ ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], d_inner * ggml_element_size(xz));
+
+ // conv
+ {
+ // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
+ ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
+
+ // copy last (d_conv - 1) columns back into the state cache
+ ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, conv_x->nb[1], conv_x->nb[2],
+ n_seq_tokens * (conv_x->nb[0]));
+
+ ggml_build_forward_expand(
+ gf, ggml_cpy(ctx0, last_conv,
+ ggml_view_1d(ctx0, conv_states_all, (d_conv - 1) * (d_inner) * (n_seqs),
+                                      kv_head * (d_conv - 1) * (d_inner) * ggml_element_size(conv_states_all))));
+
+ // 1D convolution
+ // The equivalent is to make a self-overlapping view of conv_x
+ // over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row,
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // For simultaneous sequences, all sequences need to have the same length.
+ x = ggml_ssm_conv(ctx0, conv_x, layer.ssm_conv1d);
+
+ // bias
+ x = ggml_add(ctx0, x, layer.ssm_conv1d_b);
+
+ x = ggml_silu(ctx0, x);
+ }
+
+ // ssm
+ {
+ // {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs}
+ ggml_tensor * x_db = build_lora_mm(layer.ssm_x, x);
+ // split
+ ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0);
+ ggml_tensor * B =
+ ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
+ x_db->nb[2], ggml_element_size(x_db) * dt_rank);
+ ggml_tensor * C =
+ ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state * x_db->nb[0], x_db->nb[1],
+ x_db->nb[2], ggml_element_size(x_db) * (dt_rank + d_state));
+
+ // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers
+ if (ssm_dt_b_c_rms || (layer.ssm_dt_norm && layer.ssm_b_norm && layer.ssm_c_norm)) {
+ dt = build_norm(dt, layer.ssm_dt_norm, NULL, LLM_NORM_RMS, il);
+ B = build_norm(B, layer.ssm_b_norm, NULL, LLM_NORM_RMS, il);
+ C = build_norm(C, layer.ssm_c_norm, NULL, LLM_NORM_RMS, il);
+ }
+
+ // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs}
+ dt = build_lora_mm(layer.ssm_dt, dt);
+ dt = ggml_add(ctx0, dt, layer.ssm_dt_b);
+
+ cur = x;
+ x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
+
+ ggml_tensor * A = layer.ssm_a;
+
+ // use the states and the indices provided by build_recurrent_state
+ // (this is necessary in order to properly use the states before they are overwritten,
+        // while avoiding unnecessary copies of the states)
+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+ // Custom operator to optimize the parallel associative scan
+ // as described in the Annex D of the Mamba paper.
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+ };
+
+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+
+ // store last states
+ ggml_build_forward_expand(
+ gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, x->nb[3] * x->ne[3]),
+ ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
+ kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
+
+ ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0);
+
+ // TODO: skip computing output earlier for unused tokens
+
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, layer.ssm_d));
+ y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+
+ // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+ cur = build_lora_mm(layer.ssm_out, y);
+ }
+
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+
+ return cur;
+}
+
+ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
+ ggml_tensor * cur,
+ const llama_model & model,
+ const llama_ubatch & ubatch,
+ int il) const {
+ const auto * mctx_cur = inp->mctx;
+
+ const auto kv_head = mctx_cur->get_head();
+
+ const int64_t d_conv = hparams.ssm_d_conv;
+ const int64_t d_inner = hparams.ssm_d_inner;
+ const int64_t d_state = hparams.ssm_d_state;
+ const int64_t n_head = hparams.ssm_dt_rank;
+ const int64_t head_dim = d_inner / n_head;
+ const int64_t n_group = hparams.ssm_n_group;
+ const int64_t n_seqs = ubatch.n_seqs;
+
+ const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+ GGML_ASSERT(n_seqs != 0);
+ GGML_ASSERT(ubatch.equal_seqs());
+ GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+ ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+ ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+
+ ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
+ conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs);
+
+ // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
+ cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+ // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
+
+ // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
+ ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
+
+ // split the above in three
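+    // => z:   {head_dim, n_head, n_seq_tokens, n_seqs}
+    //    xBC: {d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs}
+    //    dt:  {n_head, n_seq_tokens, n_seqs}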
+ ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0],
+ zxBCdt->nb[1], zxBCdt->nb[2], 0);
+ ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2 * n_group * d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1],
+ zxBCdt->nb[2], d_inner * ggml_element_size(zxBCdt));
+ ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2],
+ (2 * d_inner + 2 * n_group * d_state) * ggml_element_size(zxBCdt));
+
+ // conv
+ {
+ // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
+ ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
+
+ // copy last (d_conv - 1) columns back into the state cache
+ ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2 * n_group * d_state, n_seqs,
+ conv_x->nb[1], conv_x->nb[2], n_seq_tokens * (conv_x->nb[0]));
+
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv,
+ ggml_view_1d(ctx0, conv_states_all,
+ (d_conv - 1) * (d_inner + 2 * n_group * d_state) * (n_seqs),
+ kv_head * (d_conv - 1) * (d_inner + 2 * n_group * d_state) *
+ ggml_element_size(conv_states_all))));
+
+ // 1D convolution
+ // The equivalent is to make a self-overlapping view of conv_x
+ // over d_conv columns at each stride in the 3rd dimension,
+ // then element-wise multiply that with the conv1d weight,
+ // then sum the elements of each row,
+ // (the last two steps are a dot product over rows (also doable with mul_mat))
+ // then permute away the ne[0] dimension,
+ // and then you're left with the resulting x tensor.
+ // For simultaneous sequences, all sequences need to have the same length.
+ xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
+
+ // bias
+ xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
+
+ xBC = ggml_silu(ctx0, xBC);
+ }
+
+ // ssm
+ {
+ // These correspond to V K Q in SSM/attention duality
+ ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * xBC->nb[0],
+ xBC->nb[1], xBC->nb[2], 0);
+ ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
+ xBC->nb[1], xBC->nb[2], d_inner * ggml_element_size(xBC));
+ ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state * xBC->nb[0],
+ xBC->nb[1], xBC->nb[2], (d_inner + n_group * d_state) * ggml_element_size(xBC));
+
+ // {n_head, n_seq_tokens, n_seqs}
+ dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
+
+ ggml_tensor * A = model.layers[il].ssm_a;
+
+ // use the states and the indices provided by build_recurrent_state
+ // (this is necessary in order to properly use the states before they are overwritten,
+        // while avoiding unnecessary copies of the states)
+ auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
+ ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, mctx_cur->get_size());
+
+ // TODO: use semistructured matrices to implement state-space duality
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
+ return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
+ };
+
+ ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
+
+ // store last states
+ ggml_build_forward_expand(
+ gf, ggml_cpy(ctx0, ggml_view_1d(ctx0, y_ssm, d_state * d_inner * n_seqs, ggml_nelements(x) * x->nb[0]),
+ ggml_view_1d(ctx0, ssm_states_all, d_state * d_inner * n_seqs,
+ kv_head * d_state * d_inner * ggml_element_size(ssm_states_all))));
+
+ ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head * x->nb[1],
+ n_seq_tokens * n_head * x->nb[1], 0);
+
+ // TODO: skip computing output earlier for unused tokens
+
+ y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
+ cb(y, "mamba2_y_add_d", il);
+ y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
+
+ // grouped RMS norm
+ if (model.layers[il].ssm_norm) {
+ y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
+ y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+ }
+
+ y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
+
+ // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
+ cur = build_lora_mm(model.layers[il].ssm_out, y);
+ }
+
+ // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
+ cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+ cb(cur, "mamba_out", il);
+
+ return cur;
+}
#include "models.h"
-
-llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) {
+llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
#pragma once
-#include "../llama-model.h"
-#include "../llama-graph.h"
+#include "llama-model.h"
+#include "llama-graph.h"
-// TODO: remove in follow-up PR - move to .cpp files
-#include "../llama-memory-recurrent.h"
+// note: almost all graphs require at least sqrtf, so include <cmath> globally
#include <cmath>
-struct llm_graph_context_mamba : public llm_graph_context {
- llm_graph_context_mamba(const llm_graph_params & params);
+//
+// base classes
+//
- virtual ~llm_graph_context_mamba() = default;
+struct llm_build_mamba_base : public llm_graph_context {
+ llm_build_mamba_base(const llm_graph_params & params);
+
+ virtual ~llm_build_mamba_base() = default;
ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const;
};
-// Base class for RWKV-related models
+struct llm_build_delta_net_base : public llm_graph_context {
+ llm_build_delta_net_base(const llm_graph_params & params);
+
+ virtual ~llm_build_delta_net_base() = default;
+
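+    // q/k/v: [S, H, n_tokens, n_seqs], g: per-head log-decay gate, b: per-head beta,
+    // s: recurrent state [S_v, S_v, H_v, n_seqs]
+    //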
+ // returns pair of output and new state
+ std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ ggml_tensor * s,
+ int il);
+
+ // returns pair of output and new state
+ std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
+ ggml_tensor * q,
+ ggml_tensor * k,
+ ggml_tensor * v,
+ ggml_tensor * g,
+ ggml_tensor * b,
+ ggml_tensor * s,
+ int il);
+};
+
struct llm_build_rwkv6_base : public llm_graph_context {
const llama_model & model;
int il) const;
};
+//
+// models
+//
+
struct llm_build_afmoe : public llm_graph_context {
llm_build_afmoe(const llama_model & model, const llm_graph_params & params);
};
llm_build_falcon(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_falcon_h1 : public llm_graph_context_mamba {
+struct llm_build_falcon_h1 : public llm_build_mamba_base {
llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params);
};
const int il);
};
-struct llm_build_granite_hybrid : public llm_graph_context_mamba {
+struct llm_build_granite_hybrid : public llm_build_mamba_base {
llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params);
ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il);
ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn,
llm_build_jais(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_jamba : public llm_graph_context_mamba {
+struct llm_build_jamba : public llm_build_mamba_base {
llm_build_jamba(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_kimi_linear : public llm_graph_context_mamba {
+// TODO: derive from llm_build_delta_net_base instead
+struct llm_build_kimi_linear : public llm_build_mamba_base {
llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_mamba : public llm_graph_context_mamba {
+struct llm_build_mamba : public llm_build_mamba_base {
llm_build_mamba(const llama_model & model, const llm_graph_params & params);
};
llm_build_nemotron(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_nemotron_h : public llm_graph_context_mamba {
+struct llm_build_nemotron_h : public llm_build_mamba_base {
llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params);
- ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il);
+ ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il);
ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn,
- const llama_model & model, const int64_t n_embd_head, const int il);
+ const llama_model & model, int64_t n_embd_head, int il);
};
struct llm_build_neo_bert : public llm_graph_context {
llm_build_phi3(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_plamo2 : public llm_graph_context_mamba {
+struct llm_build_plamo2 : public llm_build_mamba_base {
llm_build_plamo2(const llama_model & model, const llm_graph_params & params);
private:
ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il);
llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
};
-struct llm_build_qwen3next : public llm_graph_context_mamba {
+struct llm_build_qwen3next : public llm_build_delta_net_base {
llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
private:
ggml_tensor * build_layer_attn(
ggml_tensor * cur,
int il);
- // returns pair of output and new state
- std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
- ggml_tensor * q,
- ggml_tensor * k,
- ggml_tensor * v,
- ggml_tensor * g,
- ggml_tensor * beta,
- ggml_tensor * state,
- int il);
-
- // returns pair of output and new state
- std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
- ggml_tensor * q,
- ggml_tensor * k,
- ggml_tensor * v,
- ggml_tensor * g,
- ggml_tensor * beta,
- ggml_tensor * state,
- int il);
-
ggml_tensor * build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
const llama_model & model;
};
-struct llm_build_qwen35 : public llm_graph_context_mamba {
+// TODO: derive from llm_build_delta_net_base instead
+struct llm_build_qwen35 : public llm_graph_context {
llm_build_qwen35(const llama_model & model, const llm_graph_params & params);
private:
ggml_tensor * build_layer_attn(
ggml_tensor * diag_mask,
int il);
+
ggml_tensor * build_layer_ffn(
ggml_tensor * cur,
int il);
const llama_model & model;
};
-struct llm_build_qwen35moe : public llm_graph_context_mamba {
+// TODO: derive from llm_build_delta_net_base instead
+struct llm_build_qwen35moe : public llm_graph_context {
llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params);
private:
ggml_tensor * build_layer_attn(
#include "models.h"
-
-
llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params) {
+ llm_build_mamba_base(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur,
llm_graph_input_attn_kv * inp_attn,
const llama_model & model,
- const int64_t n_embd_head,
- const int il) {
+ int64_t n_embd_head,
+ int il) {
// compute Q and K
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
return cur;
}
-ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) {
+ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) {
if (model.layers[il].ffn_gate_inp == nullptr) {
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
#include "models.h"
+#include "llama-memory-recurrent.h"
+
llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params) {
+ llm_build_mamba_base(params) {
ggml_tensor * cur;
ggml_tensor * inpL;
-#include "ggml.h"
#include "models.h"
+#include "llama-memory-recurrent.h"
+
#define CHUNK_SIZE 64
llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params), model(model) {
+ llm_graph_context(params), model(model) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-#include "ggml.h"
#include "models.h"
+#include "llama-memory-recurrent.h"
+
#define CHUNK_SIZE 64
llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params), model(model) {
+ llm_graph_context(params), model(model) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-#include "ggml.h"
#include "models.h"
-#define CHUNK_SIZE 64
+#include "llama-memory-recurrent.h"
llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) :
- llm_graph_context_mamba(params), model(model) {
+ llm_build_delta_net_base(params), model(model) {
ggml_tensor * cur;
ggml_tensor * inpL;
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
- ggml_tensor * q,
- ggml_tensor * k,
- ggml_tensor * v,
- ggml_tensor * g,
- ggml_tensor * b,
- ggml_tensor * s,
- int il) {
- const int64_t S_k = q->ne[0];
- const int64_t H_k = q->ne[1];
- const int64_t n_tokens = q->ne[2];
- const int64_t n_seqs = q->ne[3];
-
- const int64_t S_v = v->ne[0];
- const int64_t H_v = v->ne[1];
-
- GGML_ASSERT(S_k == S_v);
- GGML_ASSERT(H_v % H_k == 0);
-
- GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
- GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
- GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
-
- GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
- GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
- GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
-
- const float scale = 1.0f / sqrtf(S_k);
-
- q = ggml_scale(ctx0, q, scale);
-
- cb(q, "q_in", il);
- cb(k, "k_in", il);
- cb(v, "v_in", il);
- cb(b, "b_in", il);
- cb(g, "g_in", il);
-
- q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
- k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
- v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
- g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs]
- b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs]
-
- const int CS = CHUNK_SIZE;
-
- const int pad = (CS - n_tokens % CS) % CS;
- const int n_chunks = (n_tokens + pad) / CS;
-
- q = ggml_pad(ctx0, q, 0, pad, 0, 0);
- k = ggml_pad(ctx0, k, 0, pad, 0, 0);
- v = ggml_pad(ctx0, v, 0, pad, 0, 0);
- g = ggml_pad(ctx0, g, 0, pad, 0, 0);
- b = ggml_pad(ctx0, b, 0, pad, 0, 0);
-
- ggml_tensor * v_b = ggml_mul(ctx0, v, b);
- ggml_tensor * k_b = ggml_mul(ctx0, k, b);
-
- cb(v_b, "v_b", il);
- cb(k_b, "k_b", il);
-
- q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs);
- k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs);
- k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
- v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs);
- v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
-
- g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
- b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
-
- // [CS, 1, n_chunks, H_v * n_seqs]
- ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
- cb(g_cs, "g_cs", il);
-
- ggml_tensor * g_cs_i = g_cs;
- ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
-
- g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
-
- // [CS, CS, n_chunks, H_v * n_seqs]
- ggml_tensor * decay_mask;
- decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
- decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
- decay_mask = ggml_exp(ctx0, decay_mask);
- cb(decay_mask, "decay_mask", il);
-
- // [CS, CS, n_chunks, H_k * n_seqs]
- ggml_tensor * kb;
- kb = ggml_mul_mat(ctx0, k, k_b);
- kb = ggml_mul (ctx0, kb, decay_mask);
-
- // [CS, CS, n_chunks, H_k * n_seqs]
- ggml_tensor * attn;
- attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
-
- ggml_tensor * identity;
- identity = ggml_view_1d(ctx0, attn, CS, 0);
- identity = ggml_fill (ctx0, identity, 1.0f);
- identity = ggml_diag (ctx0, identity);
-
- ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
- cb(lhs, "dnet_add_ch_lhs", il);
-
- attn = ggml_neg(ctx0, attn);
-
- ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
- attn = ggml_add(ctx0, lin_solve, identity);
- cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
-
- // [S_v, CS, n_chunks, H_v * n_seqs]
- v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
-
- // [CS, 1, n_chunks, H_v * n_seqs]
- ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
-
- k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
-
- // [CS, S_k, n_chunks, H_k * n_seqs]
- ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
- cb(kbg, "k_beta_g_exp", il);
-
- // [S_k, CS, n_chunks, H_k * n_seqs]
- ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
- cb(k_cd, "k_cumdecay", il);
-
- // [S_k, CS, n_chunks, H_k * n_seqs]
- ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
- ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
-
- // [CS, CS, n_chunks, H_k * n_seqs]
- ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
- kq = ggml_mul(ctx0, kq, decay_mask);
- kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
- cb(kq, "kq", il);
-
- // vectorized calculation of key_gdiff
- // improved from the chunked version:
- // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
- // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
- // key_gdiff = key * g_diff.unsqueeze(-1)
- // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
- // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
- // get last element in g_cumsum along CS dimension (ne0)
- // example: [[x, y, z, ..., last], ...] -> [[last], ...]
- // [1, 1, n_chunks, H_v * n_seqs]
- ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
- g_cs->nb[1],
- g_cs->nb[2],
- g_cs->nb[3],
- ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
- cb(g_last, "g_last", il);
-
- // TODO: remove this cont when CUDA supports non-cont unary ops
- g_last = ggml_cont(ctx0, g_last);
-
- // [1, 1, n_chunks, H_v * n_seqs]
- ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
- cb(g_last_exp, "g_last_exp", il);
-
- // [CS, 1, n_chunks, H_v * n_seqs]
- ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
- cb(g_diff, "g_diff", il);
-
- ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
- ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
-
- // [S_k, CS, n_chunks, H_v * n_seqs]
- ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
- cb(kg, "key_gdiff", il);
-
- // [CS, S_k, n_chunks, H_v * n_seqs]
- ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
- cb(kg_t, "key_gdiff_t", il);
-
- ggml_tensor * s_t = ggml_transpose(ctx0, s);
- s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
- cb(s_t, "dnet_add_ch_state", il);
-
- // [CS, S_v, n_chunks, H_v * n_seqs]
- ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
-
- for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
- ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs]
- ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs]
- ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs]
- ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs]
- ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]
-
- // [CS, S_v, 1, H_v * n_seqs]
- ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
- cb(v_t_p, "v_prime", il);
-
- // [CS, S_v, 1, H_v * n_seqs]
- ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
- cb(v_t_new, "v_t_new", il);
-
- // [S_v, CS, 1, H_v * n_seqs]
- ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
- cb(v_attn, "v_attn", il);
-
- // [S_v, CS, 1, H_v * n_seqs]
- ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
- cb(attn_inter, "attn_inter", il);
-
- // [S_v, CS, 1, H_v * n_seqs]
- ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
- cb(o_ch, "dnet_add_ch_attn_out", il);
-
- v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
-
- // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
- // TODO: head broadcast might not work here - probably will need a transpose
- ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
-
- // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
- ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
- s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
- s_t = ggml_add(ctx0, s_t, kgv);
- cb(s_t, "dnet_add_ch_state", il);
- }
-
- s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
-
- // truncate padded tokens
- ggml_tensor * o = ggml_view_4d(ctx0, v,
- S_v, n_tokens, H_v, n_seqs,
- ggml_row_size(v->type, S_v),
- ggml_row_size(v->type, S_v * CS * n_chunks),
- ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
-
- o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
- s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
-
- return {o, s};
-}
-
-std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
- ggml_tensor * q,
- ggml_tensor * k,
- ggml_tensor * v,
- ggml_tensor * g,
- ggml_tensor * b, // beta
- ggml_tensor * s, // state
- int il) {
- const int64_t S_k = q->ne[0];
- const int64_t H_k = q->ne[1];
- const int64_t n_tokens = q->ne[2];
- const int64_t n_seqs = q->ne[3];
-
- const int64_t S_v = v->ne[0];
- const int64_t H_v = v->ne[1];
-
- GGML_ASSERT(n_tokens == 1);
-
- GGML_ASSERT(S_k == S_v);
- GGML_ASSERT(H_v % H_k == 0);
-
- GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
- GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
- GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
-
- GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
- GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
- GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
-
- const float scale = 1.0f / sqrtf(S_k);
-
- q = ggml_scale(ctx0, q, scale);
-
- q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
- k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
- v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
-
- cb(q, "q_in", il);
- cb(k, "k_in", il);
- cb(v, "v_in", il);
- cb(b, "b_in", il);
- cb(g, "g_in", il);
-
- g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
- b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
-
- // [S_v, S_v, H_v, n_seqs]
- g = ggml_exp(ctx0, g);
- s = ggml_mul(ctx0, s, g);
-
- ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
-
- // [1, S_v, H_v, n_seqs]
- ggml_tensor * sk;
- sk = ggml_mul (ctx0, s_t, k);
- sk = ggml_sum_rows(ctx0, sk);
-
- // [S_v, 1, H_v, n_seqs]
- ggml_tensor * d;
- d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
- d = ggml_mul(ctx0, d, b);
-
- // [1, S_v, H_v, n_seqs]
- ggml_tensor * d_t;
- d_t = ggml_transpose(ctx0, d);
-
- // [S_v, S_v, H_v, n_seqs]
- ggml_tensor * kd;
- k = ggml_repeat(ctx0, k, s);
- kd = ggml_mul (ctx0, k, d_t);
-
- s_t = ggml_add(ctx0, s_t, kd);
-
- cb(s_t, "dnet_add_ar_state", il);
-
- ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
- ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
-
- o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
- s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
-
- return {o, s};
-}
-
ggml_tensor * llm_build_qwen3next::build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
#include "models.h"
+#include "llama-memory-recurrent.h"
+
llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model) {}
#include "models.h"
+#include "llama-memory-recurrent.h"
+
llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params),
model(model) {}