]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
models : optimize qwen3next graph (#19375)
authorGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 10:57:36 +0000 (12:57 +0200)
committerGitHub <redacted>
Sat, 14 Feb 2026 10:57:36 +0000 (12:57 +0200)
* models : optimizing qwen3next graph

* cont

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* wip

* cont : remove redundant q, g chunking

* minor

* minor

* avoid passing masks around

* avoid concats during chunking

* naming + shapes

* update names and use prefix to disable CUDA graphs

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-metal/ggml-metal-common.cpp
src/models/models.h
src/models/qwen3next.cpp

index 85ce96958fa0c050e4a5349d7366bbf7bbf250f4..bed5c71a1bdaf2dd3510cacb87f49fc81ebdc48b 100644 (file)
@@ -2872,6 +2872,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
     const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
     const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
     const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
+    const std::string delta_net_prefix = "dnet_add";
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2902,7 +2903,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
             strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
             strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
             strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
-            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
+            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 &&
+            strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) {
             // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
             // by means of matching node names. See
             // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -4544,6 +4546,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_CEIL:
                 case GGML_UNARY_OP_ROUND:
                 case GGML_UNARY_OP_TRUNC:
+                    // TODO: should become:
+                    //return ggml_is_contiguous_rows(op->src[0]);
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
index 87e13786849001c2e7c361d8292da9e6e53679f3..2eb9820bff91c7e753b7cc72f03a72dd88449199 100644 (file)
@@ -273,6 +273,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
             case GGML_OP_DIAG:
             case GGML_OP_MUL:
             case GGML_OP_ADD:
+            case GGML_OP_SUB:
             case GGML_OP_DIV:
             case GGML_OP_GLU:
             case GGML_OP_SCALE:
index 3c66d325314fda9e1555178b84f0f9ad3670c355..ec6f80e5265949753e089ac0a1d7d2bdb20df14d 100644 (file)
@@ -489,9 +489,6 @@ private:
     ggml_tensor * build_layer_attn_linear(
          llm_graph_input_rs * inp,
                 ggml_tensor * cur,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
                         int   il);
 
     ggml_tensor * build_layer_ffn(
@@ -506,9 +503,6 @@ private:
                 ggml_tensor * g,
                 ggml_tensor * beta,
                 ggml_tensor * state,
-                ggml_tensor * causal_mask,
-                ggml_tensor * identity,
-                ggml_tensor * diag_mask,
                         int   il);
 
     // returns pair of output and new state
index 99b1a76a485e8cd373cad6bf60fc5b1753d5a74b..aea8b29513e737da6a265474898171674f3142b0 100644 (file)
@@ -16,17 +16,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_tensor * inp_pos     = build_inp_pos();
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
-    ggml_tensor * causal_mask =
-        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
-                    GGML_TRI_TYPE_LOWER);
-
-    ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
-    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);
-
-    ggml_build_forward_expand(gf, causal_mask);
-    ggml_build_forward_expand(gf, identity);
-    ggml_build_forward_expand(gf, diag_mask);
-
     for (int il = 0; il < n_layer; ++il) {
         ggml_tensor * inpSA = inpL;
 
@@ -36,7 +25,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
             // Linear attention layer (gated delta net)
-            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
+            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
         } else {
             // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
@@ -99,11 +88,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
-        ggml_tensor * causal_mask,
-        ggml_tensor * identity,
-        ggml_tensor * diag_mask,
+        ggml_tensor * b,
+        ggml_tensor * s,
         int           il) {
     const int64_t S_k      = q->ne[0];
     const int64_t H_k      = q->ne[1];
@@ -113,134 +99,123 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];
 
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
 
     GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
 
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
-
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
 
-    const float scale = 1.0f / sqrtf(S_v);
+    const float scale = 1.0f / sqrtf(S_k);
 
     q = ggml_scale(ctx0, q, scale);
 
-    beta = ggml_sigmoid(ctx0, beta);
-
     cb(q, "q_in", il);
     cb(k, "k_in", il);
     cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
+    cb(b, "b_in", il);
     cb(g, "g_in", il);
 
-    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
-    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
-
-    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
-
-    cb(q, "q_perm", il);
-    cb(k, "k_perm", il);
-    cb(v, "v_perm", il);
-    cb(beta, "beta_perm", il);
-    cb(g, "g_perm", il);
-    cb(state, "state_in", il);
-
-    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
-    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
-    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
-    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
+    g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [  1, n_tokens, H_v, n_seqs]
+    b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [  1, n_tokens, H_v, n_seqs]
 
-    // Do padding
-    const int64_t chunk_size = CHUNK_SIZE;
+    const int CS = CHUNK_SIZE;
 
-    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
-    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+    const int pad = (CS - n_tokens % CS) % CS;
+    const int n_chunks = (n_tokens + pad) / CS;
 
     q = ggml_pad(ctx0, q, 0, pad, 0, 0);
     k = ggml_pad(ctx0, k, 0, pad, 0, 0);
     v = ggml_pad(ctx0, v, 0, pad, 0, 0);
-    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
-    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
+    b = ggml_pad(ctx0, b, 0, pad, 0, 0);
 
-    cb(q, "q_pad", il);
-    cb(k, "k_pad", il);
-    cb(v, "v_pad", il);
-    cb(beta, "beta_pad", il);
-    cb(g, "g_pad", il);
+    ggml_tensor * v_b = ggml_mul(ctx0, v, b);
+    ggml_tensor * k_b = ggml_mul(ctx0, k, b);
 
-    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
-    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+    cb(v_b, "v_b", il);
+    cb(k_b, "k_b", il);
 
-    cb(v_beta, "v_beta", il);
-    cb(k_beta, "k_beta", il);
+    q   = ggml_reshape_4d(ctx0, q,   S_k, CS, n_chunks, H_k * n_seqs);
+    k   = ggml_reshape_4d(ctx0, k,   S_k, CS, n_chunks, H_k * n_seqs);
+    k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs);
+    v   = ggml_reshape_4d(ctx0, v,   S_v, CS, n_chunks, H_v * n_seqs);
+    v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);
 
-    q      = ggml_reshape_4d(ctx0, q,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k      = ggml_reshape_4d(ctx0, k,      S_k, chunk_size, n_chunks, H_k * n_seqs);
-    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
-    v      = ggml_reshape_4d(ctx0, v,      S_v, chunk_size, n_chunks, H_v * n_seqs);
-    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);
+    g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
 
-    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
-    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
+    // [CS, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
+    cb(g_cs, "g_cs", il);
 
-    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
-    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
+    ggml_tensor * g_cs_i = g_cs;
+    ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
 
-    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
-    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
+    g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
 
-    ggml_tensor * gcs_j_broadcast =
-        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
+    // [CS, CS, n_chunks, H_v * n_seqs]
+    ggml_tensor * decay_mask;
+    decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
+    decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    cb(decay_mask, "decay_mask", il);
 
-    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * kb;
+    kb = ggml_mul_mat(ctx0, k,  k_b);
+    kb = ggml_mul    (ctx0, kb, decay_mask);
 
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
-    decay_mask = ggml_exp(ctx0, decay_mask);
-    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * attn;
+    attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
 
-    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);
+    ggml_tensor * identity;
+    identity = ggml_view_1d(ctx0, attn, CS, 0);
+    identity = ggml_fill   (ctx0, identity, 1.0f);
+    identity = ggml_diag   (ctx0, identity);
 
-    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
-    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
+    ggml_tensor * lhs = ggml_add(ctx0, attn, identity);
+    cb(lhs, "dnet_add_ch_lhs", il);
 
-    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
-    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+    attn = ggml_neg(ctx0, attn);
 
-    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
-    attn                     = ggml_mul(ctx0, lin_solve, causal_mask);
-    attn                     = ggml_add(ctx0, attn, identity);
-    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
+    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
+    attn = ggml_add(ctx0, lin_solve, identity);
+    cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs]
 
-    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
+    // [S_v, CS, n_chunks, H_v * n_seqs]
+    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);
 
-    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
-    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);
+    // [CS, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);
 
-    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
+    k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
 
-    ggml_tensor * k_cumdecay =
-        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
-    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
+    // [CS, S_k, n_chunks, H_k * n_seqs]
+    ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp);
+    cb(kbg, "k_beta_g_exp", il);
 
-    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
-    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
-    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
-    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
+    // [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
+    cb(k_cd, "k_cumdecay", il);
 
+    // [S_k, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
+    ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);
+
+    // [CS, CS, n_chunks, H_k * n_seqs]
+    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+    kq = ggml_mul(ctx0, kq, decay_mask);
+    kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
+    cb(kq, "kq", il);
 
     // vectorized calculation of key_gdiff
     // improved from the chunked version:
@@ -250,109 +225,98 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
     //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
     //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
 
-    // get last element in g_cumsum along chunk_size dimension (ne0)
+    // get last element in g_cumsum along CS dimension (ne0)
     // example: [[x, y, z, ..., last], ...] -> [[last], ...]
-    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
-                                        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
-                                        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
+    // [1, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
+            g_cs->nb[1],
+            g_cs->nb[2],
+            g_cs->nb[3],
+            ggml_row_size(g_cs->type, g_cs->ne[0] - 1));
+    cb(g_last, "g_last", il);
+
+    // TODO: remove this cont when CUDA supports non-cont unary ops
     g_last = ggml_cont(ctx0, g_last);
-    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
 
+    // [1, 1, n_chunks, H_v * n_seqs]
     ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
-    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
-
-    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
-    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
+    cb(g_last_exp, "g_last_exp", il);
 
-    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
-                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
+    // [CS, 1, n_chunks, H_v * n_seqs]
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
+    cb(g_diff, "g_diff", il);
 
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
-    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
+    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
+    ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
 
-    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
-    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
+    // [S_k, CS, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
+    cb(kg, "key_gdiff", il);
 
+    // [CS, S_k, n_chunks, H_v * n_seqs]
+    ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
+    cb(kg_t, "key_gdiff_t", il);
 
-    // state to be updated per chunk
-    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
-    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
+    ggml_tensor * s_t = ggml_transpose(ctx0, s);
+    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
+    cb(s_t, "dnet_add_ch_state", il);
 
-    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
-    ggml_tensor * core_attn_out = nullptr;
+    // [CS, S_v, n_chunks, H_v * n_seqs]
+    ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
 
     for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
-        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
+        ggml_tensor * ch_k_cd    = get_slice_2d(ctx0, k_cd,    chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_v_t     = get_slice_2d(ctx0, v_t,     chunk); // [ CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * ch_kq      = get_slice_2d(ctx0, kq,      chunk); // [ CS,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k,  CS, 1, H_k * n_seqs]
+        ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
 
-        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
+        cb(v_t_p, "v_prime", il);
 
-        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
-        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
+        // [CS, S_v, 1, H_v * n_seqs]
+        ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p);
+        cb(v_t_new, "v_t_new", il);
 
-        // shape: (chunk_size, 1, H_v * n_seqs)
-        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq);
+        cb(v_attn, "v_attn", il);
 
-        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        // replaced by precomputed attn_kq
-        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
-        cb(attn_chunk, "attn_chunk", il);
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
+        cb(attn_inter, "attn_inter", il);
 
-        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
+        // [S_v, CS, 1, H_v * n_seqs]
+        ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn);
+        cb(o_ch, "dnet_add_ch_attn_out", il);
 
-        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
-        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
-        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
-
-        // v_new = v_i - v_prime
-        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
-        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
-        cb(v_new, "v_new_chunk", il);
-
-        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
-        cb(attn_inter, "attn_inter_chunk", il);
-
-        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
-        cb(v_attn, "v_attn_chunk", il);
-
-        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
-        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
-
-        core_attn_out = core_attn_out == nullptr
-            ? core_attn_out_chunk
-            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
+        v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]);
 
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
-        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
+        // TODO: head broadcast might not work here - probably will need a transpose
+        ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]
 
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
-        new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
-            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+        ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
+        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
+        s_t = ggml_add(ctx0, s_t, kgv);
+        cb(s_t, "dnet_add_ch_state", il);
     }
 
+    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
+
     // truncate padded tokens
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+    ggml_tensor * o = ggml_view_4d(ctx0, v,
             S_v, n_tokens, H_v, n_seqs,
-            ggml_row_size(core_attn_out->type, S_v),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
-            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
-    output_tokens = ggml_cont(ctx0, output_tokens);
-    cb(output_tokens, "output_tokens", il);
+            ggml_row_size(v->type, S_v),
+            ggml_row_size(v->type, S_v * CS * n_chunks),
+            ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
 
-    // permute back to (S_v, H_v, n_tokens, n_seqs)
-    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
-    output_tokens = ggml_cont(ctx0, output_tokens);
+    o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
 
-    return {output_tokens, new_state};
+    return {o, s};
 }
 
 std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
@@ -360,8 +324,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
         ggml_tensor * k,
         ggml_tensor * v,
         ggml_tensor * g,
-        ggml_tensor * beta,
-        ggml_tensor * state,
+        ggml_tensor * b, // beta
+        ggml_tensor * s, // state
         int           il) {
     const int64_t S_k      = q->ne[0];
     const int64_t H_k      = q->ne[1];
@@ -371,75 +335,72 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_aut
     const int64_t S_v = v->ne[0];
     const int64_t H_v = v->ne[1];
 
-    GGML_ASSERT(n_tokens == 1);  // This function is optimized for single token processing
-    GGML_ASSERT(v->ne[2] == n_tokens);
-    GGML_ASSERT(k->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
-    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
-    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);
+    GGML_ASSERT(n_tokens == 1);
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
 
     GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
     GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
 
-    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
-
-    const float eps_norm = hparams.f_norm_rms_eps;
+    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
+    GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
 
-    q = ggml_l2_norm(ctx0, q, eps_norm);
-    k = ggml_l2_norm(ctx0, k, eps_norm);
+    const float scale = 1.0f / sqrtf(S_k);
 
-    const float scale = 1.0f / sqrtf(S_v);
+    q = ggml_scale(ctx0, q, scale);
 
-    q    = ggml_scale(ctx0, q, scale);
-    beta = ggml_sigmoid(ctx0, beta);
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
 
     cb(q, "q_in", il);
     cb(k, "k_in", il);
     cb(v, "v_in", il);
-    cb(beta, "beta_in", il);
+    cb(b, "b_in", il);
     cb(g, "g_in", il);
 
-    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
+    g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
+    b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
 
-    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
-    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+    // [S_v, S_v, H_v, n_seqs]
+    g = ggml_exp(ctx0, g);
+    s = ggml_mul(ctx0, s, g);
 
-    // Apply exponential to g_t
-    g_t = ggml_exp(ctx0, g_t);
+    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
 
-    // Apply the gated delta rule for the single timestep
-    // last_recurrent_state = last_recurrent_state * g_t
-    state = ggml_mul(ctx0, state, g_t);
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * sk;
+    sk = ggml_mul     (ctx0, s_t, k);
+    sk = ggml_sum_rows(ctx0, sk);
 
-    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
-    ggml_tensor * kv_mem         = ggml_mul(ctx0, state, k_t_unsqueezed);
-    // we need to sum over dim=-2, so we transpose, sum, then transpose again
-    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
+    // [S_v, 1, H_v, n_seqs]
+    ggml_tensor * d;
+    d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk));
+    d = ggml_mul(ctx0, d, b);
 
-    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
-    ggml_tensor * v_t    = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
-    // delta = (v_t - kv_mem) * beta_t
-    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem);  // both should be [S_v, 1, H_v, n_seqs]
-    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);
+    // [1, S_v, H_v, n_seqs]
+    ggml_tensor * d_t;
+    d_t = ggml_transpose(ctx0, d);
 
-    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
-    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
-    state                   = ggml_add(ctx0, state, k_t_delta);
+    // [S_v, S_v, H_v, n_seqs]
+    ggml_tensor * kd;
+    k  = ggml_repeat(ctx0, k, s);
+    kd = ggml_mul   (ctx0, k, d_t);
 
-    // Compute the attention output
-    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
-    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs);  // unsqueeze q_t
-    ggml_tensor * state_q        = ggml_mul(ctx0, state, q_t_unsqueezed);
-    // again, since it's over dim = -2, transpose, sum, transpose back
-    ggml_tensor * core_attn_out =
-        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));
+    s_t = ggml_add(ctx0, s_t, kd);
 
-    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
-    cb(core_attn_out, "output_tokens", il);
-    cb(state, "new_state", il);
+    cb(s_t, "dnet_add_ar_state", il);
 
-    return {core_attn_out, state};
+    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
+    ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
+
+    o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
+    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
+
+    return {o, s};
 }
 
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -472,39 +433,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     // Split Q projection into query and gate
     // The split should be along dimension 0 (the feature dimension)
     ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
-                                             Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+                                            Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
+    cb(Qcur, "Qcur_view", il);
+
     ggml_tensor * gate =
         ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                      Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full));
-    cb(Qcur, "Qcur", il);
     cb(gate, "gate", il);
 
-    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
-    Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
-    cb(Qcur, "Qcur_reshaped", il);
-
-    // Apply Q normalization
-    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
-    cb(Qcur, "Qcur_normed", il);
-
     ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
     cb(Kcur, "Kcur", il);
 
     ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);
 
-    // Apply K normalization
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
-    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
-    cb(Kcur, "Kcur_normed", il);
+    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
-    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
-    cb(gate, "gate_reshaped", il);
+    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Qcur, "Qcur_normed", il);
 
-    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+    cb(Kcur, "Kcur_normed", il);
 
-    // Apply RoPE
     Qcur = ggml_rope_ext(
             ctx0, Qcur, inp_pos, nullptr,
             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -519,7 +470,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Kcur, "Kcur", il);
     cb(Vcur, "Vcur", il);
 
-    // Attention computation
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
     cur = build_attn(inp,
@@ -527,10 +477,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
     cb(cur, "attn_pregate", il);
 
-    ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate);
-    cb(gate_sigmoid, "gate_sigmoid", il);
+    // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
+    gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
+
+    gate = ggml_sigmoid(ctx0, gate);
+    cb(gate, "gate_sigmoid", il);
+
+    gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
 
-    cur = ggml_mul(ctx0, cur, gate_sigmoid);
+    cur = ggml_mul(ctx0, cur, gate);
     cb(cur, "attn_gated", il);
 
     cur = build_lora_mm(model.layers[il].wo, cur);
@@ -560,7 +515,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
         cb(z, "z", il);
 
         return { qkv_mixed, z };
-
     } else {
         // legacy (slower) path
         ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
@@ -624,9 +578,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
 ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
         llm_graph_input_rs * inp,
         ggml_tensor *        cur,
-        ggml_tensor *        causal_mask,
-        ggml_tensor *        identity,
-        ggml_tensor *        diag_mask,
         int                  il) {
     const auto * mctx_cur = inp->mctx;
 
@@ -671,7 +622,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
                                    split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
-    ggml_tensor * beta  = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+    // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont
+    b = ggml_cont(ctx0, b);
+
+    ggml_tensor * beta = ggml_sigmoid(ctx0, b);
+
+    beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
 
     // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
@@ -679,6 +635,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
     ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
     cb(alpha_softplus, "a_softplus", il);
+
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a);  // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
@@ -686,8 +643,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all  = mctx_cur->get_s_l(il);
 
-    // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();
-
     // Build the convolution states tensor
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
@@ -696,11 +651,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * conv_kernel      = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
     const int64_t conv_channels    = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
-    conv_states                    = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
+
+    conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
-    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_permuted", il);
+    qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
+    cb(qkv_mixed, "qkv_mixed_transposed", il);
 
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
@@ -720,7 +676,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
     cb(conv_states_all, "conv_states_updated", il);
 
-    // Apply SSM convolution
+    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
+    state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
+    cb(state, "state_predelta", il);
+
     ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
     cb(conv_output_proper, "conv_output_raw", il);
 
@@ -734,26 +693,36 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
 
     // Extract the convolved Q, K, V from conv_output
-    ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
+    ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            0);
+
+    ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_k_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
+
+    ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
+            ggml_row_size(conv_qkv_mix->type, head_v_dim),
+            nb1_qkv,
+            nb1_qkv * n_seq_tokens,
+            ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));
+
     cb(q_conv, "q_conv", il);
-    ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
-                     head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
-    ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
-                     2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
 
-    // Unsqueeze them
-    q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
-    v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
+    const float eps_norm = hparams.f_norm_rms_eps;
 
-    ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
-    state               = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
-    cb(state, "state_predelta", il);
+    q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
+    k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);
+
+    //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
+    //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
     if (num_k_heads != num_v_heads) {
@@ -786,7 +755,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     if (n_seq_tokens == 1) {
         attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
     } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
+        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
     }
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
@@ -795,19 +764,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 
     // Update the recurrent states
     ggml_build_forward_expand(gf,
-                              ggml_cpy(ctx0, new_state,
-                                       ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
-                                                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
-
-    // Reshape both attn_out_final and z to 2D tensors for normalization
-    // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+            ggml_cpy(ctx0, new_state,
+                ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                    kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
-    ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
+    ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);
 
     // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
     ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -818,7 +783,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(cur, "linear_attn_out", il);
 
     // Reshape back to original dimensions
-    cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+    cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
+
     return cur;
 }
 
@@ -839,7 +805,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
         if (model.layers[il].ffn_up_shexp != nullptr) {
             ggml_tensor * ffn_shexp =
                 build_ffn(cur,
-                    model.layers[il].ffn_up_shexp, NULL, NULL,
+                    model.layers[il].ffn_up_shexp,   NULL, NULL,
                     model.layers[il].ffn_gate_shexp, NULL, NULL,
                     model.layers[il].ffn_down_shexp, NULL, NULL,
                     NULL,
@@ -852,11 +818,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
             ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
             cb(shared_gate, "shared_expert_gate", il);
 
-            // Apply sigmoid to the gate
             shared_gate = ggml_sigmoid(ctx0, shared_gate);
             cb(shared_gate, "shared_expert_gate_sigmoid", il);
 
-            // Apply the gate to the shared expert output
             ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
             cb(ffn_shexp, "ffn_shexp_gated", il);