]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
graph : remove redundant GDN state transposes (llama/20443)
authorGeorgi Gerganov <redacted>
Fri, 13 Mar 2026 20:12:54 +0000 (22:12 +0200)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* ggml : transpose fused GDN state access for coalesced memory reads (llama/20436)

The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix
column-wise on row-major storage, causing strided reads (stride S_v =
128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a
39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused
path.

Transpose the state indexing so threads read contiguously:
- Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v)
- CUDA:  curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced)
- CPU:   restructured loops for row-wise transposed access

Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so
users can control fused GDN independently of auto-detection.

All GATED_DELTA_NET backend-ops tests pass.

Co-Authored-By: Claude Opus 4.6 <redacted>
* ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags

- Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized
  dot products in the CPU fused GDN kernel (delta and attention output)
- Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one
  path lacks device support, disable both to prevent state layout mismatch
  between transposed (fused) and non-transposed (unfused) formats

Co-Authored-By: Claude Opus 4.6 <redacted>
* llama : rever fgdn argument changes

* graph : remove GDN state transposes

* vulkan : adapt

* cuda : remove obsolete smem code

---------

Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
Co-authored-by: Oliver Simons <redacted>
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/gated_delta_net.cu
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp

index 85db02d92f19238cd861c577d3f1cd1eceb6c63d..314cc1088a0625abb06d6d5310057c20459ad4d2 100644 (file)
@@ -10477,34 +10477,40 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
             const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
             const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
 
+            // state is stored transposed: s_out[j*S_v + i] = S[i][j]
+            // so row j of s_out = column j of S (contiguous access)
+
             if (kda) {
+                // precompute exp(g) into delta scratch (reused below)
                 for (int64_t i = 0; i < S_v; ++i) {
-                    ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
+                    delta[i] = expf(g_d[i]);
+                }
+                // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
+                for (int64_t j = 0; j < S_v; ++j) {
+                    ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
                 }
             } else {
                 ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
             }
 
-            // delta[j] = sum_i S[j][i] * k[i]
-            memset(delta, 0, S_v * sizeof(float));
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
-            }
+            // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
             for (int64_t j = 0; j < S_v; ++j) {
-                delta[j] = (v_d[j] - delta[j]) * beta_val;
+                float sum = 0.0f;
+                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
+                delta[j] = (v_d[j] - sum) * beta_val;
             }
 
-            // outer product: S[j][i] += k[i] * delta[j]
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
+            // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
+            for (int64_t j = 0; j < S_v; ++j) {
+                ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
             }
 
-            // attn_out[j] = sum_i S[j][i] * q[i]
-            memset(attn_data, 0, S_v * sizeof(float));
-            for (int64_t i = 0; i < S_v; ++i) {
-                ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
+            // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
+            for (int64_t j = 0; j < S_v; ++j) {
+                float sum = 0.0f;
+                ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
+                attn_data[j] = sum * scale;
             }
-            ggml_vec_scale_f32(S_v, attn_data, scale);
 
             attn_data += S_v * H; // advance to next token
         }
index 5f0fa8e58dfedf48a9f12a4c8b7bf4d6fdaa2d0e..1ce6d5f31b5780774e567b365f89b158997c0a62 100644 (file)
@@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q,
     static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
     constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
     float         s_shard[rows_per_lane];
+    // state is stored transposed: M[col][i] = S[i][col], row col is contiguous
 #pragma unroll
     for (int r = 0; r < rows_per_lane; r++) {
         const int i = r * warp_size + lane;
-        s_shard[r]  = curr_state[i * S_v + col];
+        s_shard[r]  = curr_state[col * S_v + i];
     }
 
     for (int t = 0; t < n_tokens; t++) {
@@ -126,23 +127,14 @@ __global__ void gated_delta_net_cuda(const float * q,
         attn_data += S_v * H;
     }
 
-    // Write state back to global memory
+    // Write state back to global memory (transposed layout)
 #pragma unroll
     for (int r = 0; r < rows_per_lane; r++) {
         const int i          = r * warp_size + lane;
-        state[i * S_v + col] = s_shard[r];
+        state[col * S_v + i] = s_shard[r];
     }
 }
 
-static size_t calculate_smem(const int sv, int cc)
-{
-    size_t smem = 0;
-    if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
-        smem = sv * sv * sizeof(float);
-    }
-    return smem;
-}
-
 template <bool KDA>
 static void launch_gated_delta_net(
         const float * q_d, const float * k_d, const float * v_d,
@@ -179,18 +171,14 @@ static void launch_gated_delta_net(
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         case 64: {
-            constexpr int sv = 64;
-            size_t smem = calculate_smem(sv, cc);
-            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+            gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         }
         case 128: {
-            constexpr int sv = 128;
-            size_t smem = calculate_smem(sv, cc);
-            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+            gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
                 sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
index 107e7cf2ff36bc18ac69a4cd8dede47621b21453..d4b129ed756dc684616d9f67e2120563bc7adff6 100644 (file)
@@ -2469,13 +2469,14 @@ kernel void kernel_gated_delta_net_impl(
 
     const float scale = 1.0f / sqrt((float)S_v);
 
-    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
 
     float ls[NSG];
 
     FOR_UNROLL (short j = 0; j < NSG; j++) {
         const short is = tx*NSG + j;
-        ls[j] = s_ptr[is*S_v];
+        ls[j] = s_ptr[is];
     }
 
     device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
@@ -2536,11 +2537,11 @@ kernel void kernel_gated_delta_net_impl(
         g_ptr += args.ne21*G;
     }
 
-    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
 
     FOR_UNROLL (short j = 0; j < NSG; j++) {
         const short is = tx*NSG + j;
-        dst_state[is*S_v] = ls[j];
+        dst_state[is] = ls[j];
     }
 
 #undef S_v
index 1fdf889e824684e0ad994f973a469ee4dfa56fc4..f008859b99da67d41230116dcc9db6848bc2d371 100644 (file)
@@ -44,7 +44,7 @@ void main() {
 
     FLOAT_TYPE state[S_V];
     [[unroll]] for (uint i = 0; i < S_V; i++) {
-        state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+        state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
     }
 
     uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
@@ -123,6 +123,6 @@ void main() {
     }
 
     [[unroll]] for (uint i = 0; i < S_V; i++) {
-        data_dst[s_off + state_base + i * S_V + col] = state[i];
+        data_dst[s_off + state_base + col * S_V + i] = state[i];
     }
 }