]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
graph : remove redundant GDN state transposes (#20443)
authorGeorgi Gerganov <redacted>
Fri, 13 Mar 2026 20:12:54 +0000 (22:12 +0200)
committerGitHub <redacted>
Fri, 13 Mar 2026 20:12:54 +0000 (22:12 +0200)
* ggml : transpose fused GDN state access for coalesced memory reads (#20436)

The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix
column-wise on row-major storage, causing strided reads (stride S_v =
128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a
39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused
path.

Transpose the state indexing so threads read contiguously:
- Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v)
- CUDA:  curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced)
- CPU:   restructured loops for row-wise transposed access

Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so
users can control fused GDN independently of auto-detection.

All GATED_DELTA_NET backend-ops tests pass.

Co-Authored-By: Claude Opus 4.6 <redacted>
* ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags

- Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized
  dot products in the CPU fused GDN kernel (delta and attention output)
- Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one
  path lacks device support, disable both to prevent state layout mismatch
  between transposed (fused) and non-transposed (unfused) formats

Co-Authored-By: Claude Opus 4.6 <redacted>
* llama : rever fgdn argument changes

* graph : remove GDN state transposes

* vulkan : adapt

* cuda : remove obsolete smem code

---------

Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
Co-authored-by: Oliver Simons <redacted>
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/gated_delta_net.cu
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
src/models/delta-net-base.cpp

index 85db02d92f19238cd861c577d3f1cd1eceb6c63d..314cc1088a0625abb06d6d5310057c20459ad4d2 100644 (file)
@@ -10477,34 +10477,40 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
             const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
             const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
 
+            // state is stored transposed: s_out[j*S_v + i] = S[i][j]
+            // so row j of s_out = column j of S (contiguous access)
+
             if (kda) {
+                // precompute exp(g) into delta scratch (reused below)
                 for (int64_t i = 0; i < S_v; ++i) {
-                    ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
+                    delta[i] = expf(g_d[i]);
+                }
+                // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
+                for (int64_t j = 0; j < S_v; ++j) {
+                    ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
                 }
             } else {
                 ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
             }
 
-            // delta[j] = sum_i S[j][i] * k[i]
-            memset(delta, 0, S_v * sizeof(float));
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
-            }
+            // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
             for (int64_t j = 0; j < S_v; ++j) {
-                delta[j] = (v_d[j] - delta[j]) * beta_val;
+                float sum = 0.0f;
+                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
+                delta[j] = (v_d[j] - sum) * beta_val;
             }
 
-            // outer product: S[j][i] += k[i] * delta[j]
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
+            // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
+            for (int64_t j = 0; j < S_v; ++j) {
+                ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
             }
 
-            // attn_out[j] = sum_i S[j][i] * q[i]
-            memset(attn_data, 0, S_v * sizeof(float));
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
+            // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
+            for (int64_t j = 0; j < S_v; ++j) {
+                float sum = 0.0f;
+                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
+                attn_data[j] = sum * scale;
             }
-            ggml_vec_scale_f32(S_v, attn_data, scale);
 
             attn_data += S_v * H; // advance to next token
         }
index 5f0fa8e58dfedf48a9f12a4c8b7bf4d6fdaa2d0e..1ce6d5f31b5780774e567b365f89b158997c0a62 100644 (file)
@@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q,
     static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
     constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
     float         s_shard[rows_per_lane];
+    // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
 #pragma unroll
     for (int r = 0; r < rows_per_lane; r++) {
         const int i = r * warp_size + lane;
-        s_shard[r]  = curr_state[i * S_v + col];
+        s_shard[r]  = curr_state[col * S_v + i];
     }
 
     for (int t = 0; t < n_tokens; t++) {
@@ -126,23 +127,14 @@ __global__ void gated_delta_net_cuda(const float * q,
         attn_data += S_v * H;
     }
 
-    // Write state back to global memory
+    // Write state back to global memory (transposed layout)
 #pragma unroll
     for (int r = 0; r < rows_per_lane; r++) {
         const int i          = r * warp_size + lane;
-        state[i * S_v + col] = s_shard[r];
+        state[col * S_v + i] = s_shard[r];
     }
 }
 
-static size_t calculate_smem(const int sv, int cc)
-{
-    size_t smem = 0;
-    if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
-        smem = sv * sv * sizeof(float);
-    }
-    return smem;
-}
-
 template <bool KDA>
 static void launch_gated_delta_net(
         const float * q_d, const float * k_d, const float * v_d,
@@ -179,18 +171,14 @@ static void launch_gated_delta_net(
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         case 64: {
-            constexpr int sv = 64;
-            size_t smem = calculate_smem(sv, cc);
-            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+            gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         }
         case 128: {
-            constexpr int sv = 128;
-            size_t smem = calculate_smem(sv, cc);
-            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+            gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
index 107e7cf2ff36bc18ac69a4cd8dede47621b21453..d4b129ed756dc684616d9f67e2120563bc7adff6 100644 (file)
@@ -2469,13 +2469,14 @@ kernel void kernel_gated_delta_net_impl(
 
     const float scale = 1.0f / sqrt((float)S_v);
 
-    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
 
     float ls[NSG];
 
     FOR_UNROLL (short j = 0; j < NSG; j++) {
         const short is = tx*NSG + j;
-        ls[j] = s_ptr[is*S_v];
+        ls[j] = s_ptr[is];
     }
 
     device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
@@ -2536,11 +2537,11 @@ kernel void kernel_gated_delta_net_impl(
         g_ptr += args.ne21*G;
     }
 
-    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
 
     FOR_UNROLL (short j = 0; j < NSG; j++) {
         const short is = tx*NSG + j;
-        dst_state[is*S_v] = ls[j];
+        dst_state[is] = ls[j];
     }
 
 #undef S_v
index 1fdf889e824684e0ad994f973a469ee4dfa56fc4..f008859b99da67d41230116dcc9db6848bc2d371 100644 (file)
@@ -44,7 +44,7 @@ void main() {
 
     FLOAT_TYPE state[S_V];
     [[unroll]] for (uint i = 0; i < S_V; i++) {
-        state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+        state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
     }
 
     uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@@ -123,6 +123,6 @@ void main() {
     }
 
     [[unroll]] for (uint i = 0; i < S_V; i++) {
-        data_dst[s_off + state_base + i * S_V + col] = state[i];
+        data_dst[s_off + state_base + col * S_V + i] = state[i];
     }
 }
index a62dbc15dd0558a914cb8f36e8e751303d680f22..6bc989c950992a4d43591b5d3f12016947b2ba53 100644 (file)
@@ -225,9 +225,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
     cb(kg_t, "key_gdiff_t", il);
 
-    ggml_tensor * s_t = ggml_transpose(ctx0, s);
-    s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
-    cb(s_t, "dnet_add_ch_state", il);
+    s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
+    cb(s, "dnet_add_ch_state", il);
 
     // [CS, S_v, n_chunks, H_v * n_seqs]
     ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
@@ -240,7 +239,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
         ggml_tensor * ch_kg_t    = get_slice_2d(ctx0, kg_t,    chunk); // [ CS, S_k, 1, H_v * n_seqs]
 
         // [CS, S_v, 1, H_v * n_seqs]
-        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
+        ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
         cb(v_t_p, "v_prime", il);
 
         // [CS, S_v, 1, H_v * n_seqs]
@@ -252,7 +251,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
         cb(v_attn, "v_attn", il);
 
         // [S_v, CS, 1, H_v * n_seqs]
-        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
         cb(attn_inter, "attn_inter", il);
 
         // [S_v, CS, 1, H_v * n_seqs]
@@ -268,13 +267,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
         ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
 
-        s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
-        s_t = ggml_add(ctx0, s_t, kgv);
-        cb(s_t, "dnet_add_ch_state", il);
+        s = ggml_mul(ctx0, s, ch_g_last_exp_t);
+        s = ggml_add(ctx0, s, kgv);
+        cb(s, "dnet_add_ch_state", il);
     }
 
-    s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
-
     // truncate padded tokens
     ggml_tensor * o = ggml_view_4d(ctx0, v,
             S_v, n_tokens, H_v, n_seqs,
@@ -282,7 +279,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
             ggml_row_size(v->type, S_v * CS * n_chunks),
             ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
     o = ggml_permute  (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
-    s = ggml_transpose(ctx0, s_t);
+    s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
     cb(s, "output_state", il);
 
     return {o, s};
@@ -341,11 +338,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     g = ggml_exp(ctx0, g);
     s = ggml_mul(ctx0, s, g);
 
-    ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
-
     // [1, S_v, H_v, n_seqs]
     ggml_tensor * sk;
-    sk = ggml_mul     (ctx0, s_t, k);
+    sk = ggml_mul     (ctx0, s, k);
     sk = ggml_sum_rows(ctx0, sk);
 
     // [S_v, 1, H_v, n_seqs]
@@ -362,15 +357,14 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     k  = ggml_repeat(ctx0, k, s);
     kd = ggml_mul   (ctx0, k, d_t);
 
-    s_t = ggml_add(ctx0, s_t, kd);
+    s = ggml_add(ctx0, s, kd);
 
-    cb(s_t, "dnet_add_ar_state", il);
+    cb(s, "dnet_add_ar_state", il);
 
-    ggml_tensor * s_q = ggml_mul     (ctx0, s_t, q);
+    ggml_tensor * s_q = ggml_mul     (ctx0, s, q);
     ggml_tensor * o   = ggml_sum_rows(ctx0, s_q);
 
     o = ggml_permute  (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
-    s = ggml_transpose(ctx0, s_t);           // [S_v, S_v, H_v, n_seqs]
 
     return {o, s};
 }