]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : enable chunked fused GDN path (#20340)
authorGeorgi Gerganov <redacted>
Wed, 11 Mar 2026 20:46:40 +0000 (22:46 +0200)
committerGitHub <redacted>
Wed, 11 Mar 2026 20:46:40 +0000 (22:46 +0200)
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <redacted>
* cont : fix the fix

Co-authored-by: Aman Gupta <redacted>
* cont : fix

* metal : add GDN kernel (#20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <redacted>
* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <redacted>
* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <redacted>
* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
* CUDA: AR gated delta net improvements (#20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 206890897546bd16602c3b79394fd5ea09ef199f

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <redacted>
* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <redacted>
Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
Co-authored-by: Oliver Simons <redacted>
Co-authored-by: uvos <redacted>
20 files changed:
ggml/include/ggml.h
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/gated_delta_net.cu
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal-ops.h
ggml/src/ggml-metal/ggml-metal.metal
src/llama-context.cpp
src/llama-cparams.h
src/llama-impl.h
src/models/delta-net-base.cpp
src/models/kimi-linear.cpp
src/models/models.h
src/models/qwen35.cpp
src/models/qwen35moe.cpp
src/models/qwen3next.cpp
tests/test-backend-ops.cpp

index 3323f8e6c3fea805022bd44e9f11bd38da2cdb68..25f9601e9b5e74af84da0d5b1058633e01b8fa38 100644 (file)
@@ -2466,6 +2466,8 @@ extern "C" {
         bool                  lower,
         bool                  uni);
 
+    // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
     GGML_API struct ggml_tensor * ggml_gated_delta_net(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
index f9c4ec16e4b78c5e1cca8915ecc0cd8ae66761c3..fa9d27046b5ea0101c827fce68cec60218b50804 100644 (file)
@@ -10443,8 +10443,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 
     const float * state_in_base = (const float *)src_state->data;
 
-    const int64_t rq1 = nev1 / neq1;
-    const int64_t rk1 = nev1 / nek1;
+  //const int64_t rq1 = nev1 / neq1;
+  //const int64_t rk1 = nev1 / nek1;
     const int64_t rq3 = nev3 / neq3;
     const int64_t rk3 = nev3 / nek3;
 
@@ -10454,8 +10454,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
         const int64_t iv1 = ir % H; // head_index
         const int64_t iv3 = ir / H; // sequence
 
-        const int64_t iq1 = iv1 / rq1;
-        const int64_t ik1 = iv1 / rk1;
+        const int64_t iq1 = iv1 % neq1;
+        const int64_t ik1 = iv1 % nek1;
 
         const int64_t iq3 = iv3 / rq3;
         const int64_t ik3 = iv3 / rk3;
@@ -10475,7 +10475,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
             const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
 
             const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
-            const float * g_d   =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+            const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
 
             if (kda) {
                 for (int64_t i = 0; i < S_v; ++i) {
@@ -10508,7 +10508,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 
             attn_data += S_v * H; // advance to next token
         }
-
     }
 }
 
index c249bbc86d5a8f83ca11dbba98740de971a327a9..5f0fa8e58dfedf48a9f12a4c8b7bf4d6fdaa2d0e 100644 (file)
@@ -1,36 +1,36 @@
 #include "gated_delta_net.cuh"
-#include "ggml-cuda/common.cuh"
 
 template <int S_v, bool KDA>
-__global__ void __launch_bounds__(S_v, 1)
-gated_delta_net_cuda(const float * q,
-                     const float * k,
-                     const float * v,
-                     const float * g,
-                     const float * beta,
-                     const float * curr_state,
-                     float *       dst,
-                     const int64_t H,
-                     const int64_t n_tokens,
-                     const int64_t n_seqs,
-                     const int64_t sq1,
-                     const int64_t sq2,
-                     const int64_t sq3,
-                     const int64_t sv1,
-                     const int64_t sv2,
-                     const int64_t sv3,
-                     const int64_t sb1,
-                     const int64_t sb2,
-                     const int64_t sb3,
-                     const int64_t rq1,
-                     const int64_t rq3,
-                     const float   scale) {
-    const int64_t h_idx    = blockIdx.x;
-    const int64_t sequence = blockIdx.y;
-    const int     col      = threadIdx.x;  // each thread owns one column
-
-    const int64_t iq1 = h_idx / rq1;
-    const int64_t iq3 = sequence / rq3;
+__global__ void gated_delta_net_cuda(const float * q,
+                                     const float * k,
+                                     const float * v,
+                                     const float * g,
+                                     const float * beta,
+                                     const float * curr_state,
+                                     float *       dst,
+                                     int64_t       H,
+                                     int64_t       n_tokens,
+                                     int64_t       n_seqs,
+                                     int64_t       sq1,
+                                     int64_t       sq2,
+                                     int64_t       sq3,
+                                     int64_t       sv1,
+                                     int64_t       sv2,
+                                     int64_t       sv3,
+                                     int64_t       sb1,
+                                     int64_t       sb2,
+                                     int64_t       sb3,
+                                     const uint3   neqk1_magic,
+                                     const uint3   rq3_magic,
+                                     float         scale) {
+    const uint32_t h_idx    = blockIdx.x;
+    const uint32_t sequence = blockIdx.y;
+    // each warp owns one column, using warp-level primitives to reduce across rows
+    const int      lane     = threadIdx.x;
+    const int      col      = blockIdx.z * blockDim.y + threadIdx.y;
+
+    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+    const uint32_t iq3 = fastdiv(sequence, rq3_magic);
 
     const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
     float *       attn_data        = dst;
@@ -41,17 +41,14 @@ gated_delta_net_cuda(const float * q,
     curr_state += state_offset;
     attn_data += (sequence * n_tokens * H + h_idx) * S_v;
 
-    // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
-    // TODO: check optimal path for RDNA1 and RDNA2 devices.
-#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
-    extern __shared__ float s_shared[];
-    float * s = s_shared + col * S_v;
-#else
-    float s[S_v];
-#endif
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
+    static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
+    float         s_shard[rows_per_lane];
 #pragma unroll
-    for (int i = 0; i < S_v; i++) {
-        s[i] = curr_state[i * S_v + col];
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i = r * warp_size + lane;
+        s_shard[r]  = curr_state[i * S_v + col];
     }
 
     for (int t = 0; t < n_tokens; t++) {
@@ -69,46 +66,61 @@ gated_delta_net_cuda(const float * q,
             const float g_val = expf(*g_t);
 
             // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
-            float kv_col = 0.0f;
+            float kv_shard = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                kv_col += s[i] * k_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += s_shard[r] * k_t[i];
             }
+            float kv_col = warp_reduce_sum<warp_size>(kv_shard);
 
             // delta[col] = (v[col] - g * kv[col]) * beta
             float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
 
             // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
             // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
-            float attn_col = 0.0f;
+            float attn_partial = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                s[i] = g_val * s[i] + k_t[i] * delta_col;
-                attn_col += s[i] * q_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
             }
 
-            attn_data[col] = attn_col * scale;
+            float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
         } else {
             // kv[col] = sum_i g[i] * S[i][col] * k[i]
-            float kv_col = 0.0f;
+            float kv_shard = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                kv_col += expf(g_t[i]) * s[i] * k_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
             }
 
+            float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
             // delta[col] = (v[col] - kv[col]) * beta
             float delta_col = (v_t[col] - kv_col) * beta_val;
 
             // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
             // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
-            float attn_col = 0.0f;
+            float attn_partial = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
-                attn_col += s[i] * q_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
             }
 
-            attn_data[col] = attn_col * scale;
+            float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
         }
 
         attn_data += S_v * H;
@@ -116,8 +128,9 @@ gated_delta_net_cuda(const float * q,
 
     // Write state back to global memory
 #pragma unroll
-    for (int i = 0; i < S_v; i++) {
-        state[i * S_v + col] = s[i];
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i          = r * warp_size + lane;
+        state[i * S_v + col] = s_shard[r];
     }
 }
 
@@ -135,35 +148,43 @@ static void launch_gated_delta_net(
         const float * q_d, const float * k_d, const float * v_d,
         const float * g_d, const float * b_d, const float * s_d,
         float * dst_d,
-        int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
-        int64_t sq1, int64_t sq2, int64_t sq3,
-        int64_t sv1, int64_t sv2, int64_t sv3,
-        int64_t sb1, int64_t sb2, int64_t sb3,
-        int64_t rq1, int64_t rq3,
+        int64_t S_v,   int64_t H, int64_t n_tokens, int64_t n_seqs,
+        int64_t sq1,   int64_t sq2, int64_t sq3,
+        int64_t sv1,   int64_t sv2, int64_t sv3,
+        int64_t sb1,   int64_t sb2, int64_t sb3,
+        int64_t neqk1, int64_t rq3,
         float scale, cudaStream_t stream) {
+    //TODO: Add chunked kernel for even faster pre-fill
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int num_warps = 4;
+    dim3      grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+    dim3      block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
 
-    dim3 grid_dims(H, n_seqs, 1);
-    dim3 block_dims(S_v, 1, 1);
+    const uint3 neqk1_magic = init_fastdiv_values(neqk1);
+    const uint3 rq3_magic   = init_fastdiv_values(rq3);
 
     int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
     switch (S_v) {
-        case 32: {
-            constexpr int sv = 32;
-            size_t smem = calculate_smem(sv, cc);
-            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+        case 16:
+            gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+            break;
+        case 32:
+            gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
-        }
         case 64: {
             constexpr int sv = 64;
             size_t smem = calculate_smem(sv, cc);
             gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         }
         case 128: {
@@ -172,7 +193,7 @@ static void launch_gated_delta_net(
             gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         }
         default:
@@ -190,10 +211,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_tensor * src_state = dst->src[5];
 
     GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
-    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
     GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
-    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
-    GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
 
     const int64_t S_v      = nev0;
     const int64_t H        = nev1;
@@ -202,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const bool kda = (src_g->ne[0] == S_v);
 
-    const int64_t rq1 = nev1 / neq1;
+    GGML_ASSERT(neq1 == nek1);
+    const int64_t neqk1 = neq1;
+
     const int64_t rq3 = nev3 / neq3;
 
     const float * q_d = (const float *) src_q->data;
@@ -241,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     if (kda) {
         launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
             S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-            sb1, sb2, sb3, rq1, rq3, scale, stream);
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
     } else {
         launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
             S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-            sb1, sb2, sb3, rq1, rq3, scale, stream);
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
     }
 }
index 169c63dd7a457bb4ccdbd421669ac9c49434d8f8..15ae2e517df75ac293524bac10a452c35c8f9f82 100644 (file)
@@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    // v is src[2], dimensions: S_v = ne[0], H = ne[1]
+    const int ne20 = op->src[2]->ne[0]; // S_v
+    const int ne21 = op->src[2]->ne[1]; // H
+    const int ne30 = op->src[3]->ne[0]; // G
+
+    const int nsg = op->src[2]->ne[0]/32;
+
+    GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->ne[0] == ne20 * ne21);
+    GGML_ASSERT(ne20 % 32 == 0);
+
+    snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
+    snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
+        ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg = nsg;
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
index 93d7f6a216f976bf7f251ffde0add5dbf338e2a2..fd2b3ddeb5589f3d33a36d3f81f979a1558fcff4 100644 (file)
@@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net   (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
index d42b8ab1eb1874a2d0deece92697f7c8d5d36b21..05b826a61b833f217d00b8f0ac803fe3226305d6 100644 (file)
@@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true;
+        case GGML_OP_GATED_DELTA_NET:
+            return op->src[2]->ne[0] % 32 == 0;
         case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
index 99d64efc3b54319aeebd7c80386f728155a2858d..53437b23cda7bf818bccf70e8fa015d2fa653309 100644 (file)
@@ -84,6 +84,7 @@
 #define FC_BIN                         1300
 #define FC_SUM_ROWS                    1400
 #define FC_UPSCALE                     1500
+#define FC_GATED_DELTA_NET             1600
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -793,6 +794,44 @@ typedef struct {
     uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne20;
+    int32_t  ne21;
+    int32_t  ne22;
+    int32_t  ne23;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ns02;
+    int32_t  ns12;
+    int32_t  ns22;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_gated_delta_net;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
index 267755d08ccf59277f90269d67f91ad2dea9d7cf..306dbcf36609d303bd7fb4652c62dfad148e54e3 100644 (file)
@@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_rwkv(ctx, idx);
             } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
+            } break;
         case GGML_OP_SOLVE_TRI:
             {
                 n_fuse = ggml_metal_op_solve_tri(ctx, idx);
@@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
+
+    int ida = 0;
+
+    ggml_metal_kargs_gated_delta_net args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne20 =*/ ne20,
+        /*.ne21 =*/ ne21,
+        /*.ne22 =*/ ne22,
+        /*.ne23 =*/ ne23,
+        /*.nb20 =*/ nb20,
+        /*.nb21 =*/ nb21,
+        /*.nb22 =*/ nb22,
+        /*.nb23 =*/ nb23,
+        /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
+        /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
+        /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args),                  ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++); // dst
+
+    const int nsg = pipeline.nsg;
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index f3e38c7aa9de287316a7c429bb17a718e212e11c..019f2fec9edded7c9a74d35f27e9d5b94603e662 100644 (file)
@@ -58,6 +58,7 @@ int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_gated_delta_net   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
index 29e4a245d5d6aee537d2fee9f807706df98c02bd..0b77d5349b861f03101c766d6f5490ca0ecc24a4 100644 (file)
@@ -2434,6 +2434,227 @@ kernel void kernel_rwkv_wkv7_f32(
     }
 }
 
+constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
+constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
+
+#if 1
+template<short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]])  {
+#define S_v FC_gated_delta_net_ne20
+#define G   FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x;
+    const uint ty = tpitg.y;
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty;
+
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    float ls[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        ls[j] = s_ptr[is*S_v];
+    }
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        float s_k = 0.0f;
+
+        if (G == 1) {
+            const float g_exp = exp(g_ptr[0]);
+
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= g_exp;
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        } else {
+            // KDA
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= exp(g_ptr[is]);
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        }
+
+        s_k = simd_sum(s_k);
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        float y = 0.0f;
+
+        FOR_UNROLL (short j = 0; j < NSG; j++) {
+            const short is = tx*NSG + j;
+            ls[j] += k_ptr[is]*d;
+
+            y += ls[j]*q_ptr[is];
+        }
+
+        y = simd_sum(y);
+
+        if (tx == 0) {
+            dst_attn[t*args.ne21*S_v] = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+    }
+
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = ls[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
+
+#else
+// a simplified version of the above
+// no performance improvement, so keep the above version for now
+
+template<typename T, short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]])  {
+#define S_v FC_gated_delta_net_ne20
+#define G   FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x;
+    const uint ty = tpitg.y;
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty;
+
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    float lsf[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        lsf[j] = s_ptr[is*S_v];
+    }
+
+    thread T * ls = (thread T *) (lsf);
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr  = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr  = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        device const T * qt_ptr = (device const T *) (q_ptr);
+        device const T * kt_ptr = (device const T *) (k_ptr);
+        device const T * gt_ptr = (device const T *) (g_ptr);
+
+        if (G == 1) {
+            *ls *= exp(g_ptr[0]);
+        } else {
+            // KDA
+            *ls *= exp(gt_ptr[tx]);
+        }
+
+        const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        *ls += kt_ptr[tx]*d;
+
+        const float y = simd_sum(dot(*ls, qt_ptr[tx]));
+
+        if (tx == 0) {
+            *dst_attn = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+
+        dst_attn += args.ne21*S_v;
+    }
+
+    device float * dst_state  = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    device T     * dstt_state = (device T     *) (dst_state);
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = lsf[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float,  1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
+#endif
+
 constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
 constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
 constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
index ee2669c154ec83462880673aeb734102a0cbbab2..0be9493910253c0255e40cbbe50fe683e7067237 100644 (file)
@@ -151,7 +151,8 @@ llama_context::llama_context(
     cparams.auto_fa    = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
 
     cparams.fused_gdn_ar = true;
-    cparams.fused_gdn_ch = false; // TODO: implement
+    cparams.fused_gdn_ch = true;
+    cparams.auto_fgdn    = true;
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -462,37 +463,81 @@ void llama_context::sched_reserve() {
         cparams.auto_fa = false;
     }
 
-    if (cparams.fused_gdn_ar) {
-        auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
-        if (!gf) {
-            throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
-        }
+    if (cparams.auto_fgdn) {
+        LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__);
 
-        const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
-        bool gdn_device_mismatch = false;
-        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-            ggml_tensor * n = ggml_graph_node(gf, i);
-            if (n->op != GGML_OP_GATED_DELTA_NET) {
-                continue;
+        if (cparams.fused_gdn_ar) {
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
             }
-            ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
 
-            GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
-            const int il = std::stoi(n->name + prefix_len);
-            ggml_backend_dev_t device_kv = model.dev_layer(il);
-            if (device_gdn != device_kv) {
-                LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
-                        "is assigned to device %s (usually due to missing support)\n",
-                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
-                gdn_device_mismatch = true;
-                break;
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
+            bool gdn_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_GATED_DELTA_NET) {
+                    continue;
+                }
+                ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_gdn != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
+                            "is assigned to device %s (usually due to missing support)\n",
+                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
+                    gdn_device_mismatch = true;
+                    break;
+                }
+            }
+
+            if (gdn_device_mismatch) {
+                cparams.fused_gdn_ar = false;
+                LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
+            } else {
+                LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
             }
         }
 
-        if (gdn_device_mismatch) {
-            cparams.fused_gdn_ar = false;
-            LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
+        if (cparams.fused_gdn_ch) {
+            // more than one token in the batch per sequence in order to take the chunked path
+            auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
+            bool gdn_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_GATED_DELTA_NET) {
+                    continue;
+                }
+                ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_gdn != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
+                            "is assigned to device %s (usually due to missing support)\n",
+                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
+                    gdn_device_mismatch = true;
+                    break;
+                }
+            }
+
+            if (gdn_device_mismatch) {
+                cparams.fused_gdn_ch = false;
+                LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
+            } else {
+                LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
+            }
         }
+
+        cparams.auto_fgdn = false;
     }
 
     // reserve worst-case graph
index 333922468c0b14bf6f6787991e6287bf13308d3b..9d359474132f91e56e0100ef21e967279c098dea 100644 (file)
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool auto_fa;
     bool fused_gdn_ar;       // use fused gated delta net (autoregressive)
     bool fused_gdn_ch;       // use fused gated delta net (chunked)
+    bool auto_fgdn;
     bool no_perf;
     bool warmup;
     bool op_offload;
index ee27ac1beaa655805963cd70eb51155ecf267398..e4f35c8e53d64cde77b8c19af976efdd4ac1767e 100644 (file)
@@ -70,6 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t);
 
 std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
 
-#define LLAMA_TENSOR_NAME_FATTN  "__fattn__"
-#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__"
-#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__"
+#define LLAMA_TENSOR_NAME_FATTN   "__fattn__"
+#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__"
+#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__"
index b0be62fc688d5eacc296bf3ac68e7e7f72c4e83f..a62dbc15dd0558a914cb8f36e8e751303d680f22 100644 (file)
@@ -41,13 +41,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
     GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
 
-    if (cparams.fused_gdn_ch) {
-        //ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
-        //cb(result, LLAMA_TENSOR_NAME_FGDNCH, il);
-
-        GGML_ABORT("not implemented yet");
-    }
-
     const float scale = 1.0f / sqrtf(S_k);
 
     q = ggml_scale(ctx0, q, scale);
@@ -325,26 +318,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
     GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
 
-    if (cparams.fused_gdn_ar) {
-        ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
-        cb(result, LLAMA_TENSOR_NAME_FGDNAR, il);
-
-        ggml_tensor * output = ggml_view_4d(ctx0, result,
-            S_v, H_v, n_tokens, n_seqs,
-            ggml_row_size(result->type, S_v),
-            ggml_row_size(result->type, S_v * H_v),
-            ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
-
-        ggml_tensor * new_state = ggml_view_4d(ctx0, result,
-            S_v, S_v, H_v, n_seqs,
-            ggml_row_size(result->type, S_v),
-            ggml_row_size(result->type, S_v * S_v),
-            ggml_row_size(result->type, S_v * S_v * H_v),
-            ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
-
-        return {output, new_state};
-    }
-
     const float scale = 1.0f / sqrtf(S_k);
 
     q = ggml_scale(ctx0, q, scale);
@@ -401,3 +374,78 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
 
     return {o, s};
 }
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(S_k == S_v);
+    GGML_ASSERT(H_v % H_k == 0);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);
+
+    GGML_ASSERT(g->ne[0] == 1   || g->ne[0] == S_v);
+    GGML_ASSERT(                   g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
+    GGML_ASSERT(b->ne[0] == 1   && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
+    GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v      && s->ne[3] == n_seqs);
+
+    ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
+    if (n_tokens == 1) {
+        cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
+    } else {
+        cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
+    }
+
+    ggml_tensor * output = ggml_view_4d(ctx0, result,
+            S_v, H_v, n_tokens, n_seqs,
+            ggml_row_size(result->type, S_v),
+            ggml_row_size(result->type, S_v * H_v),
+            ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
+
+    ggml_tensor * new_state = ggml_view_4d(ctx0, result,
+            S_v, S_v, H_v, n_seqs,
+            ggml_row_size(result->type, S_v),
+            ggml_row_size(result->type, S_v * S_v),
+            ggml_row_size(result->type, S_v * S_v * H_v),
+            ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
+
+    return {output, new_state};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * g,
+        ggml_tensor * b,
+        ggml_tensor * s,
+        int           il) {
+    const int64_t n_seq_tokens = q->ne[2];
+
+    if (n_seq_tokens == 1) {
+        if (cparams.fused_gdn_ar) {
+            return build_delta_net_fused(q, k, v, g, b, s, il);
+        }
+        return build_delta_net_autoregressive(q, k, v, g, b, s, il);
+    }
+
+    if (cparams.fused_gdn_ch) {
+        return build_delta_net_fused(q, k, v, g, b, s, il);
+    }
+
+    return build_delta_net_chunking(q, k, v, g, b, s, il);
+}
index 063b17a2f6a20a3f75370f77253ca3aaaa4511e0..4d62f4e7159aa5c9dbcd0871b9b4b7576585265f 100644 (file)
@@ -169,9 +169,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
             Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
 
             // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
-            std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
-                build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
-                build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);
+            auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il);
 
             ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
             ggml_tensor * new_state = attn_out.second;
index cf9ba04e7f7c148222846ccce3b696aa0eef2917..a86b2b1ebd79cd0dc221c1871753394d3ef14f3a 100644 (file)
@@ -44,6 +44,26 @@ struct llm_build_delta_net_base : public llm_graph_context {
                 ggml_tensor * b,
                 ggml_tensor * s,
                 int           il);
+
+    // use the ggml_gated_delta_net fused operator
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
+
+    // choose one of two implementations above based on the number of tokens
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * g,
+                ggml_tensor * b,
+                ggml_tensor * s,
+                        int   il);
 };
 
 struct llm_build_rwkv6_base : public llm_graph_context {
index ba096a5a7ba14cac306453230ed76f17d3370b22..e12dad70012c2fbcaa1d4466ed4b07f8baca3e5b 100644 (file)
@@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
     //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
-    if (num_k_heads != num_v_heads) {
+    // note: need explicit repeat only if we are not using the fused GDN
+    if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
-        // TODO: try to avoid these explicit repeats by utilizing op broadcast
         q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
         k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
     }
@@ -332,12 +332,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    std::pair<ggml_tensor *, ggml_tensor *> attn_out;
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
-    }
+    auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);
index fe382286e984d6ec9ca48aea8b12b64c888cc5f1..8d07c7ed275c7aa998cfd706d4033b20fed53392 100644 (file)
@@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
     //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
-    if (num_k_heads != num_v_heads) {
+    // note: need explicit repeat only if we are not using the fused GDN
+    if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
-        // TODO: try to avoid these explicit repeats by utilizing op broadcast
         q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
         k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
     }
@@ -332,12 +332,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    std::pair<ggml_tensor *, ggml_tensor *> attn_out;
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
-    }
+    auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);
index 30912fd5e363b3b6324ad65b7be2a8a28b78f18b..cc479dd075c2026a3e6ae41357e552ca65aff310 100644 (file)
@@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
     // if head keys and value keys are different, repeat to force tensors into matching shapes
+    // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
     if (num_k_heads != num_v_heads) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;
@@ -431,13 +432,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(k_conv, "k_conv_predelta", il);
     cb(v_conv, "v_conv_predelta", il);
 
-    // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
-    if (n_seq_tokens == 1) {
-        attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
-    } else {
-        attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
-    }
+    auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il);
+
     ggml_tensor * output    = attn_out.first;
     ggml_tensor * new_state = attn_out.second;
     cb(output, "attn_output", il);
index 66aaddcfffd522622b6cec675adf25e236db787f..58d67d97f8d439426d4952066495a4245a52a41a 100644 (file)
@@ -8447,6 +8447,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
@@ -8456,10 +8459,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     // KDA (vector gate)
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true,  true));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true,  true));
 
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging