]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama : enable chunked fused GDN path (llama/20340)
authorGeorgi Gerganov <redacted>
Wed, 11 Mar 2026 20:46:40 +0000 (22:46 +0200)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
* llama : enable chunked fused GDN path

* models : avoid Q and K repeats when using fused GDA

* cont : fix comment

Co-authored-by: Aman Gupta <redacted>
* cont : fix the fix

Co-authored-by: Aman Gupta <redacted>
* cont : fix

* metal : add GDN kernel (llama/20361)

* metal : add Metal backend for GGML_OP_GATED_DELTA_NET

Add a fused Metal kernel for the gated delta net recurrence op
(#19504), enabling GPU-accelerated inference for DeltaNet-based
models (Qwen3.5, etc.) on Apple Silicon.

Supports both GDA (scalar gate) and KDA (per-row gate) modes
with head_size 64 and 128. Unsupported configurations (head_size
32, non-contiguous tensors) gracefully fall back to CPU.

Performance: Qwen3.5-0.8B Q4_K_M on M4 Max
  tg128: 170 -> 213 t/s (+25%)

Co-Authored-By: Claude Opus 4.6 <redacted>
* metal : validate contiguity of all input tensors in supports_op

Co-Authored-By: Claude Opus 4.6 <redacted>
* metal : add algorithm equivalence comment for GDA decay path

Co-Authored-By: Claude Opus 4.6 <redacted>
* cont : unslop + optimize

* cont : clean-up

---------

Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
* CUDA: AR gated delta net improvements (llama/20391)

* Add FastDiv to gated_delta_net_cuda

* Shard columns across warps

This reduces register pressure (avoids spill for S_v = 128) and gives
the warp-scheduler more CTAs to schedule (thus hiding data-access
latencies).

* Remove unneded include in gated_delta_net.cu

* Improve comments

* Apply code-formating

* Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA

* Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t

* Rename variables

* Enable GDN also for prefill, move TODO for chunked_GDN

* Actually remove the TODO from 206890897546bd16602c3b79394fd5ea09ef199f

* Get warp size at runtime

warp_size is not known at compile time in hip host code.

* Don't expose ggml_cuda_get_physical_warp_size on host

---------

Co-authored-by: uvos <redacted>
* llama : refactor llm_build_delta_net_base API

---------

Co-authored-by: Aman Gupta <redacted>
Co-authored-by: Paul Flynn <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
Co-authored-by: Oliver Simons <redacted>
Co-authored-by: uvos <redacted>
include/ggml.h
src/ggml-cpu/ops.cpp
src/ggml-cuda/gated_delta_net.cu
src/ggml-metal/ggml-metal-device.cpp
src/ggml-metal/ggml-metal-device.h
src/ggml-metal/ggml-metal-device.m
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal-ops.cpp
src/ggml-metal/ggml-metal-ops.h
src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index 3323f8e6c3fea805022bd44e9f11bd38da2cdb68..25f9601e9b5e74af84da0d5b1058633e01b8fa38 100644 (file)
@@ -2466,6 +2466,8 @@ extern "C" {
         bool                  lower,
         bool                  uni);
 
+    // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
     GGML_API struct ggml_tensor * ggml_gated_delta_net(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
index f9c4ec16e4b78c5e1cca8915ecc0cd8ae66761c3..fa9d27046b5ea0101c827fce68cec60218b50804 100644 (file)
@@ -10443,8 +10443,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 
     const float * state_in_base = (const float *)src_state->data;
 
-    const int64_t rq1 = nev1 / neq1;
-    const int64_t rk1 = nev1 / nek1;
+  //const int64_t rq1 = nev1 / neq1;
+  //const int64_t rk1 = nev1 / nek1;
     const int64_t rq3 = nev3 / neq3;
     const int64_t rk3 = nev3 / nek3;
 
@@ -10454,8 +10454,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
         const int64_t iv1 = ir % H; // head_index
         const int64_t iv3 = ir / H; // sequence
 
-        const int64_t iq1 = iv1 / rq1;
-        const int64_t ik1 = iv1 / rk1;
+        const int64_t iq1 = iv1 % neq1;
+        const int64_t ik1 = iv1 % nek1;
 
         const int64_t iq3 = iv3 / rq3;
         const int64_t ik3 = iv3 / rk3;
@@ -10475,7 +10475,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
             const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
 
             const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
-            const float * g_d   =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+            const float * g_d    =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
 
             if (kda) {
                 for (int64_t i = 0; i < S_v; ++i) {
@@ -10508,7 +10508,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
 
             attn_data += S_v * H; // advance to next token
         }
-
     }
 }
 
index c249bbc86d5a8f83ca11dbba98740de971a327a9..5f0fa8e58dfedf48a9f12a4c8b7bf4d6fdaa2d0e 100644 (file)
@@ -1,36 +1,36 @@
 #include "gated_delta_net.cuh"
-#include "ggml-cuda/common.cuh"
 
 template <int S_v, bool KDA>
-__global__ void __launch_bounds__(S_v, 1)
-gated_delta_net_cuda(const float * q,
-                     const float * k,
-                     const float * v,
-                     const float * g,
-                     const float * beta,
-                     const float * curr_state,
-                     float *       dst,
-                     const int64_t H,
-                     const int64_t n_tokens,
-                     const int64_t n_seqs,
-                     const int64_t sq1,
-                     const int64_t sq2,
-                     const int64_t sq3,
-                     const int64_t sv1,
-                     const int64_t sv2,
-                     const int64_t sv3,
-                     const int64_t sb1,
-                     const int64_t sb2,
-                     const int64_t sb3,
-                     const int64_t rq1,
-                     const int64_t rq3,
-                     const float   scale) {
-    const int64_t h_idx    = blockIdx.x;
-    const int64_t sequence = blockIdx.y;
-    const int     col      = threadIdx.x;  // each thread owns one column
-
-    const int64_t iq1 = h_idx / rq1;
-    const int64_t iq3 = sequence / rq3;
+__global__ void gated_delta_net_cuda(const float * q,
+                                     const float * k,
+                                     const float * v,
+                                     const float * g,
+                                     const float * beta,
+                                     const float * curr_state,
+                                     float *       dst,
+                                     int64_t       H,
+                                     int64_t       n_tokens,
+                                     int64_t       n_seqs,
+                                     int64_t       sq1,
+                                     int64_t       sq2,
+                                     int64_t       sq3,
+                                     int64_t       sv1,
+                                     int64_t       sv2,
+                                     int64_t       sv3,
+                                     int64_t       sb1,
+                                     int64_t       sb2,
+                                     int64_t       sb3,
+                                     const uint3   neqk1_magic,
+                                     const uint3   rq3_magic,
+                                     float         scale) {
+    const uint32_t h_idx    = blockIdx.x;
+    const uint32_t sequence = blockIdx.y;
+    // each warp owns one column, using warp-level primitives to reduce across rows
+    const int      lane     = threadIdx.x;
+    const int      col      = blockIdx.z * blockDim.y + threadIdx.y;
+
+    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+    const uint32_t iq3 = fastdiv(sequence, rq3_magic);
 
     const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
     float *       attn_data        = dst;
@@ -41,17 +41,14 @@ gated_delta_net_cuda(const float * q,
     curr_state += state_offset;
     attn_data += (sequence * n_tokens * H + h_idx) * S_v;
 
-    // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
-    // TODO: check optimal path for RDNA1 and RDNA2 devices.
-#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
-    extern __shared__ float s_shared[];
-    float * s = s_shared + col * S_v;
-#else
-    float s[S_v];
-#endif
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
+    static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
+    float         s_shard[rows_per_lane];
 #pragma unroll
-    for (int i = 0; i < S_v; i++) {
-        s[i] = curr_state[i * S_v + col];
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i = r * warp_size + lane;
+        s_shard[r]  = curr_state[i * S_v + col];
     }
 
     for (int t = 0; t < n_tokens; t++) {
@@ -69,46 +66,61 @@ gated_delta_net_cuda(const float * q,
             const float g_val = expf(*g_t);
 
             // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
-            float kv_col = 0.0f;
+            float kv_shard = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                kv_col += s[i] * k_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += s_shard[r] * k_t[i];
             }
+            float kv_col = warp_reduce_sum<warp_size>(kv_shard);
 
             // delta[col] = (v[col] - g * kv[col]) * beta
             float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
 
             // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
             // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
-            float attn_col = 0.0f;
+            float attn_partial = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                s[i] = g_val * s[i] + k_t[i] * delta_col;
-                attn_col += s[i] * q_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
             }
 
-            attn_data[col] = attn_col * scale;
+            float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
         } else {
             // kv[col] = sum_i g[i] * S[i][col] * k[i]
-            float kv_col = 0.0f;
+            float kv_shard = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                kv_col += expf(g_t[i]) * s[i] * k_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
             }
 
+            float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
             // delta[col] = (v[col] - kv[col]) * beta
             float delta_col = (v_t[col] - kv_col) * beta_val;
 
             // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
             // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
-            float attn_col = 0.0f;
+            float attn_partial = 0.0f;
 #pragma unroll
-            for (int i = 0; i < S_v; i++) {
-                s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
-                attn_col += s[i] * q_t[i];
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
             }
 
-            attn_data[col] = attn_col * scale;
+            float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
         }
 
         attn_data += S_v * H;
@@ -116,8 +128,9 @@ gated_delta_net_cuda(const float * q,
 
     // Write state back to global memory
 #pragma unroll
-    for (int i = 0; i < S_v; i++) {
-        state[i * S_v + col] = s[i];
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i          = r * warp_size + lane;
+        state[i * S_v + col] = s_shard[r];
     }
 }
 
@@ -135,35 +148,43 @@ static void launch_gated_delta_net(
         const float * q_d, const float * k_d, const float * v_d,
         const float * g_d, const float * b_d, const float * s_d,
         float * dst_d,
-        int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
-        int64_t sq1, int64_t sq2, int64_t sq3,
-        int64_t sv1, int64_t sv2, int64_t sv3,
-        int64_t sb1, int64_t sb2, int64_t sb3,
-        int64_t rq1, int64_t rq3,
+        int64_t S_v,   int64_t H, int64_t n_tokens, int64_t n_seqs,
+        int64_t sq1,   int64_t sq2, int64_t sq3,
+        int64_t sv1,   int64_t sv2, int64_t sv3,
+        int64_t sb1,   int64_t sb2, int64_t sb3,
+        int64_t neqk1, int64_t rq3,
         float scale, cudaStream_t stream) {
+    //TODO: Add chunked kernel for even faster pre-fill
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int num_warps = 4;
+    dim3      grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+    dim3      block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
 
-    dim3 grid_dims(H, n_seqs, 1);
-    dim3 block_dims(S_v, 1, 1);
+    const uint3 neqk1_magic = init_fastdiv_values(neqk1);
+    const uint3 rq3_magic   = init_fastdiv_values(rq3);
 
     int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
     switch (S_v) {
-        case 32: {
-            constexpr int sv = 32;
-            size_t smem = calculate_smem(sv, cc);
-            gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
+        case 16:
+            gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+            break;
+        case 32:
+            gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
-        }
         case 64: {
             constexpr int sv = 64;
             size_t smem = calculate_smem(sv, cc);
             gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         }
         case 128: {
@@ -172,7 +193,7 @@ static void launch_gated_delta_net(
             gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
             break;
         }
         default:
@@ -190,10 +211,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_tensor * src_state = dst->src[5];
 
     GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
-    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
     GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
-    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
-    GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
 
     const int64_t S_v      = nev0;
     const int64_t H        = nev1;
@@ -202,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     const bool kda = (src_g->ne[0] == S_v);
 
-    const int64_t rq1 = nev1 / neq1;
+    GGML_ASSERT(neq1 == nek1);
+    const int64_t neqk1 = neq1;
+
     const int64_t rq3 = nev3 / neq3;
 
     const float * q_d = (const float *) src_q->data;
@@ -241,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     if (kda) {
         launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
             S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-            sb1, sb2, sb3, rq1, rq3, scale, stream);
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
     } else {
         launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
             S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-            sb1, sb2, sb3, rq1, rq3, scale, stream);
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
     }
 }
index 169c63dd7a457bb4ccdbd421669ac9c49434d8f8..15ae2e517df75ac293524bac10a452c35c8f9f82 100644 (file)
@@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
+    char base[256];
+    char name[256];
+
+    // v is src[2], dimensions: S_v = ne[0], H = ne[1]
+    const int ne20 = op->src[2]->ne[0]; // S_v
+    const int ne21 = op->src[2]->ne[1]; // H
+    const int ne30 = op->src[3]->ne[0]; // G
+
+    const int nsg = op->src[2]->ne[0]/32;
+
+    GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->ne[0] == ne20 * ne21);
+    GGML_ASSERT(ne20 % 32 == 0);
+
+    snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
+    snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+        ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
+        ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);
+
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+        ggml_metal_cv_free(cv);
+    }
+
+    res.nsg = nsg;
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
index 93d7f6a216f976bf7f251ffde0add5dbf338e2a2..fd2b3ddeb5589f3d33a36d3f81f979a1558fcff4 100644 (file)
@@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net   (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
index d42b8ab1eb1874a2d0deece92697f7c8d5d36b21..05b826a61b833f217d00b8f0ac803fe3226305d6 100644 (file)
@@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true;
+        case GGML_OP_GATED_DELTA_NET:
+            return op->src[2]->ne[0] % 32 == 0;
         case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
index 99d64efc3b54319aeebd7c80386f728155a2858d..53437b23cda7bf818bccf70e8fa015d2fa653309 100644 (file)
@@ -84,6 +84,7 @@
 #define FC_BIN                         1300
 #define FC_SUM_ROWS                    1400
 #define FC_UPSCALE                     1500
+#define FC_GATED_DELTA_NET             1600
 
 // op-specific constants
 #define OP_FLASH_ATTN_EXT_NQPSG 8
@@ -793,6 +794,44 @@ typedef struct {
     uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne20;
+    int32_t  ne21;
+    int32_t  ne22;
+    int32_t  ne23;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ns02;
+    int32_t  ns12;
+    int32_t  ns22;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_gated_delta_net;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
index 267755d08ccf59277f90269d67f91ad2dea9d7cf..306dbcf36609d303bd7fb4652c62dfad148e54e3 100644 (file)
@@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_rwkv(ctx, idx);
             } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
+            } break;
         case GGML_OP_SOLVE_TRI:
             {
                 n_fuse = ggml_metal_op_solve_tri(ctx, idx);
@@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);
+
+    int ida = 0;
+
+    ggml_metal_kargs_gated_delta_net args = {
+        /*.ne00 =*/ ne00,
+        /*.ne01 =*/ ne01,
+        /*.ne02 =*/ ne02,
+        /*.ne03 =*/ ne03,
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.ne10 =*/ ne10,
+        /*.ne11 =*/ ne11,
+        /*.ne12 =*/ ne12,
+        /*.ne13 =*/ ne13,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.ne20 =*/ ne20,
+        /*.ne21 =*/ ne21,
+        /*.ne22 =*/ ne22,
+        /*.ne23 =*/ ne23,
+        /*.nb20 =*/ nb20,
+        /*.nb21 =*/ nb21,
+        /*.nb22 =*/ nb22,
+        /*.nb23 =*/ nb23,
+        /*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
+        /*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
+        /*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
+        /*.ne0  =*/ ne0,
+        /*.ne1  =*/ ne1,
+        /*.ne2  =*/ ne2,
+        /*.ne3  =*/ ne3,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+    };
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args),                  ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++); // dst
+
+    const int nsg = pipeline.nsg;
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
index f3e38c7aa9de287316a7c429bb17a718e212e11c..019f2fec9edded7c9a74d35f27e9d5b94603e662 100644 (file)
@@ -58,6 +58,7 @@ int ggml_metal_op_soft_max          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_gated_delta_net   (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_solve_tri         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set               (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
index 29e4a245d5d6aee537d2fee9f807706df98c02bd..0b77d5349b861f03101c766d6f5490ca0ecc24a4 100644 (file)
@@ -2434,6 +2434,227 @@ kernel void kernel_rwkv_wkv7_f32(
     }
 }
 
+constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
+constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];
+
+#if 1
+template<short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]])  {
+#define S_v FC_gated_delta_net_ne20
+#define G   FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x;
+    const uint ty = tpitg.y;
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty;
+
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    float ls[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        ls[j] = s_ptr[is*S_v];
+    }
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        float s_k = 0.0f;
+
+        if (G == 1) {
+            const float g_exp = exp(g_ptr[0]);
+
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= g_exp;
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        } else {
+            // KDA
+            FOR_UNROLL (short j = 0; j < NSG; j++) {
+                const short is = tx*NSG + j;
+                ls[j] *= exp(g_ptr[is]);
+
+                s_k += ls[j]*k_ptr[is];
+            }
+        }
+
+        s_k = simd_sum(s_k);
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        float y = 0.0f;
+
+        FOR_UNROLL (short j = 0; j < NSG; j++) {
+            const short is = tx*NSG + j;
+            ls[j] += k_ptr[is]*d;
+
+            y += ls[j]*q_ptr[is];
+        }
+
+        y = simd_sum(y);
+
+        if (tx == 0) {
+            dst_attn[t*args.ne21*S_v] = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+    }
+
+    device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = ls[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;
+
+#else
+// a simplified version of the above
+// no performance improvement, so keep the above version for now
+
+template<typename T, short NSG>
+kernel void kernel_gated_delta_net_impl(
+        constant ggml_metal_kargs_gated_delta_net & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * g,
+        device const char * b,
+        device const char * s,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]])  {
+#define S_v FC_gated_delta_net_ne20
+#define G   FC_gated_delta_net_ne30
+
+    const uint tx = tpitg.x;
+    const uint ty = tpitg.y;
+
+    const uint i23 = tgpig.z; // B
+    const uint i21 = tgpig.y; // H
+    const uint i20 = tgpig.x*NSG + ty;
+
+    const uint i01 = i21 % args.ne01;
+    const uint i11 = i21 % args.ne11;
+
+    const float scale = 1.0f / sqrt((float)S_v);
+
+    device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
+
+    float lsf[NSG];
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        lsf[j] = s_ptr[is*S_v];
+    }
+
+    thread T * ls = (thread T *) (lsf);
+
+    device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
+
+    device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
+    device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
+    device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);
+
+    device const float * b_ptr  = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
+    device const float * g_ptr  = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;
+
+    for (short t = 0; t < args.ne22; t++) {
+        device const T * qt_ptr = (device const T *) (q_ptr);
+        device const T * kt_ptr = (device const T *) (k_ptr);
+        device const T * gt_ptr = (device const T *) (g_ptr);
+
+        if (G == 1) {
+            *ls *= exp(g_ptr[0]);
+        } else {
+            // KDA
+            *ls *= exp(gt_ptr[tx]);
+        }
+
+        const float s_k = simd_sum(dot(*ls, kt_ptr[tx]));
+
+        const float d = (v_ptr[i20] - s_k)*b_ptr[0];
+
+        *ls += kt_ptr[tx]*d;
+
+        const float y = simd_sum(dot(*ls, qt_ptr[tx]));
+
+        if (tx == 0) {
+            *dst_attn = y*scale;
+        }
+
+        q_ptr += args.ns02;
+        k_ptr += args.ns12;
+        v_ptr += args.ns22;
+
+        b_ptr += args.ne21;
+        g_ptr += args.ne21*G;
+
+        dst_attn += args.ne21*S_v;
+    }
+
+    device float * dst_state  = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
+    device T     * dstt_state = (device T     *) (dst_state);
+
+    FOR_UNROLL (short j = 0; j < NSG; j++) {
+        const short is = tx*NSG + j;
+        dst_state[is*S_v] = lsf[j];
+    }
+
+#undef S_v
+#undef G
+}
+
+typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t;
+
+template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float,  1>;
+template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>;
+template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>;
+#endif
+
 constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
 constant short FC_solve_tri_n   [[function_constant(FC_SOLVE_TRI + 1)]];
 constant short FC_solve_tri_k   [[function_constant(FC_SOLVE_TRI + 2)]];
index 66aaddcfffd522622b6cec675adf25e236db787f..58d67d97f8d439426d4952066495a4245a52a41a 100644 (file)
@@ -8447,6 +8447,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
@@ -8456,10 +8459,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     // KDA (vector gate)
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
     test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true,  true));
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true,  true));
 
 #if 0
     // these tests are disabled to save execution time, sbut they can be handy for debugging