]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml: add GATED_DELTA_NET op (llama/19504)
authorAman Gupta <redacted>
Sat, 7 Mar 2026 07:41:10 +0000 (15:41 +0800)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* ggml: add GATED_DELTA_NET op

* remove the transpose

* add KDA

* add qwen35 dense

* llama : check for fused gated delta net backend support

---------

Co-authored-by: Georgi Gerganov <redacted>
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/ops.h
ggml/src/ggml-cuda/gated_delta_net.cu [new file with mode: 0644]
ggml/src/ggml-cuda/gated_delta_net.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml.c

index 784d69206b4a174972a0b79223dd39e1de728892..566e2714790c18ddf90592325f9359f2bf88b062 100644 (file)
@@ -556,6 +556,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_SOLVE_TRI,
+        GGML_OP_GATED_DELTA_NET,
 
         GGML_OP_UNARY,
 
@@ -2463,6 +2464,15 @@ extern "C" {
         bool                  lower,
         bool                  uni);
 
+    GGML_API struct ggml_tensor * ggml_gated_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state);
+
     // custom operators
 
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
index 7c4026fac4e067ca9aafb146a9760ff1dc337988..dc2b5ffaa77310b9656247cd2218e861b45cc17c 100644 (file)
@@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_solve_tri(params, tensor);
             } break;
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                ggml_compute_forward_gated_delta_net(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
             } break;
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_SOLVE_TRI:
+        case GGML_OP_GATED_DELTA_NET:
             {
                 n_tasks = n_threads;
             } break;
@@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan(
                     {
                         cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
                     } break;
+                case GGML_OP_GATED_DELTA_NET:
+                    {
+                        const int64_t S_v = node->src[2]->ne[0];
+                        cur = S_v * sizeof(float) * n_tasks;
+                    } break;
                 case GGML_OP_COUNT:
                     {
                         GGML_ABORT("fatal error");
index 2c372f9635b0eee5140b2ba670023aaa330b8d08..331e071a2677a8c8e2bcedcffb59d5a6efbe0288 100644 (file)
@@ -10380,6 +10380,190 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s
     }
 }
 
+// ggml_compute_forward_gated_delta_net
+static void ggml_compute_forward_gated_delta_net_one_chunk(
+    const ggml_compute_params * params,
+    ggml_tensor * dst,
+    int64_t ir0,
+    int64_t ir1) {
+
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    const int64_t S_v      = src_v->ne[0];
+    const int64_t H        = src_v->ne[1];
+    const int64_t n_tokens = src_v->ne[2];
+    const int64_t n_seqs   = src_v->ne[3];
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v);
+    GGML_ASSERT(src_beta->ne[0] == 1);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbg, src_g, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
+
+    const bool kda = (neg0 == S_v);
+
+    // scratch layout per thread: [delta(S_v)]
+    const int64_t scratch_per_thread = S_v;
+    const int ith = params->ith;
+
+    float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32;
+
+    // output layout: [attn_scores | new_states]
+    // attn_scores: S_v * H * n_tokens * n_seqs floats
+    // new_states:  S_v * S_v * H * n_seqs floats
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float * attn_out_base  = (float *)dst->data;
+    float * state_out_base = (float *)dst->data + attn_score_elems;
+
+    const float * state_in_base = (const float *)src_state->data;
+
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rk1 = nev1 / nek1;
+    const int64_t rq3 = nev3 / neq3;
+    const int64_t rk3 = nev3 / nek3;
+
+    const float scale = 1.0f / sqrtf((float) S_v);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t iv1 = ir % H; // head_index
+        const int64_t iv3 = ir / H; // sequence
+
+        const int64_t iq1 = iv1 / rq1;
+        const int64_t ik1 = iv1 / rk1;
+
+        const int64_t iq3 = iv3 / rq3;
+        const int64_t ik3 = iv3 / rk3;
+
+        float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v;
+
+        // copy input state into output buffer and operate in-place
+        const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v;
+        memcpy(s_out, s_in, S_v * S_v * sizeof(float));
+
+        // attn output pointer for first token of this (head, seq)
+        float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v;
+
+        for (int64_t t = 0; t < n_tokens; t++) {
+            const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1);
+            const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1);
+            const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);
+
+            const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
+            const float * g_d   =  (const float *)((const char *)src_g->data    + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
+
+            if (kda) {
+                for (int64_t i = 0; i < S_v; ++i) {
+                    ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
+                }
+            } else {
+                ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
+            }
+
+            // delta[j] = sum_i S[j][i] * k[i]
+            memset(delta, 0, S_v * sizeof(float));
+            for (int64_t i = 0; i < S_v; ++i) {
+                ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
+            }
+            for (int64_t j = 0; j < S_v; ++j) {
+                delta[j] = (v_d[j] - delta[j]) * beta_val;
+            }
+
+            // outer product: S[j][i] += k[i] * delta[j]
+            for (int64_t i = 0; i < S_v; ++i) {
+                ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
+            }
+
+            // attn_out[j] = sum_i S[j][i] * q[i]
+            memset(attn_data, 0, S_v * sizeof(float));
+            for (int64_t i = 0; i < S_v; ++i) {
+                ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
+            }
+            ggml_vec_scale_f32(S_v, attn_data, scale);
+
+            attn_data += S_v * H; // advance to next token
+        }
+
+    }
+}
+
+
+static void ggml_compute_forward_gated_delta_net_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    ggml_tensor * V = dst->src[2];
+    int64_t nr = V->ne[1] * V->ne[3];
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    int nth = params->nth;
+    int ith = params->ith;
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+      nchunk = nth;
+    }
+
+    if (ith == 0) {
+      ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1);
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
+void ggml_compute_forward_gated_delta_net(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gated_delta_net_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_rwkv_wkv7
 
 static void ggml_compute_forward_rwkv_wkv7_f32(
index 0fdfee79766e4aca15ec0c6f1afe1abe3645c8cf..3fa1443abc48401537c3d4da3126c90c8a6e7156 100644 (file)
@@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu
new file mode 100644 (file)
index 0000000..d8e8111
--- /dev/null
@@ -0,0 +1,223 @@
+#include "gated_delta_net.cuh"
+#include "ggml-cuda/common.cuh"
+
+template <int S_v, bool KDA>
+__global__ void gated_delta_net_cuda(const float * q,
+                                     const float * k,
+                                     const float * v,
+                                     const float * g,
+                                     const float * beta,
+                                     const float * curr_state,
+                                     float *       dst,
+                                     int64_t       H,
+                                     int64_t       n_tokens,
+                                     int64_t       n_seqs,
+                                     int64_t       sq1,
+                                     int64_t       sq2,
+                                     int64_t       sq3,
+                                     int64_t       sv1,
+                                     int64_t       sv2,
+                                     int64_t       sv3,
+                                     int64_t       sb1,
+                                     int64_t       sb2,
+                                     int64_t       sb3,
+                                     int64_t       rq1,
+                                     int64_t       rq3,
+                                     float         scale) {
+    const int64_t h_idx    = blockIdx.x;
+    const int64_t sequence = blockIdx.y;
+    const int     col      = threadIdx.x;  // each thread owns one column
+
+    const int64_t iq1 = h_idx / rq1;
+    const int64_t iq3 = sequence / rq3;
+
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float *       attn_data        = dst;
+    float *       state            = dst + attn_score_elems;
+
+    const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+    state += state_offset;
+    curr_state += state_offset;
+    attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+
+    // Load state column into registers
+    float s[S_v];
+#pragma unroll
+    for (int i = 0; i < S_v; i++) {
+        s[i] = curr_state[i * S_v + col];
+    }
+
+    for (int t = 0; t < n_tokens; t++) {
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
+
+        const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
+        const float * beta_t = beta + gb_offset;
+        const float * g_t    = g    + gb_offset * (KDA ? S_v : 1);
+
+        const float beta_val = *beta_t;
+
+        if constexpr (!KDA) {
+            const float g_val = expf(*g_t);
+
+            // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
+            float kv_col = 0.0f;
+#pragma unroll
+            for (int i = 0; i < S_v; i++) {
+                kv_col += s[i] * k_t[i];
+            }
+
+            // delta[col] = (v[col] - g * kv[col]) * beta
+            float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
+
+            // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_col = 0.0f;
+#pragma unroll
+            for (int i = 0; i < S_v; i++) {
+                s[i] = g_val * s[i] + k_t[i] * delta_col;
+                attn_col += s[i] * q_t[i];
+            }
+
+            attn_data[col] = attn_col * scale;
+        } else {
+            // kv[col] = sum_i g[i] * S[i][col] * k[i]
+            float kv_col = 0.0f;
+#pragma unroll
+            for (int i = 0; i < S_v; i++) {
+                kv_col += expf(g_t[i]) * s[i] * k_t[i];
+            }
+
+            // delta[col] = (v[col] - kv[col]) * beta
+            float delta_col = (v_t[col] - kv_col) * beta_val;
+
+            // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_col = 0.0f;
+#pragma unroll
+            for (int i = 0; i < S_v; i++) {
+                s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col;
+                attn_col += s[i] * q_t[i];
+            }
+
+            attn_data[col] = attn_col * scale;
+        }
+
+        attn_data += S_v * H;
+    }
+
+    // Write state back to global memory
+#pragma unroll
+    for (int i = 0; i < S_v; i++) {
+        state[i * S_v + col] = s[i];
+    }
+}
+
+template <bool KDA>
+static void launch_gated_delta_net(
+        const float * q_d, const float * k_d, const float * v_d,
+        const float * g_d, const float * b_d, const float * s_d,
+        float * dst_d,
+        int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
+        int64_t sq1, int64_t sq2, int64_t sq3,
+        int64_t sv1, int64_t sv2, int64_t sv3,
+        int64_t sb1, int64_t sb2, int64_t sb3,
+        int64_t rq1, int64_t rq3,
+        float scale, cudaStream_t stream) {
+
+    dim3 grid_dims(H, n_seqs, 1);
+    dim3 block_dims(S_v, 1, 1);
+
+    switch (S_v) {
+        case 32:
+            gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, rq1, rq3, scale);
+            break;
+        case 64:
+            gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, rq1, rq3, scale);
+            break;
+        case 128:
+            gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
+                q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
+                n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+                sb1, sb2, sb3, rq1, rq3, scale);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
+
+    const int64_t S_v      = nev0;
+    const int64_t H        = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs   = nev3;
+
+    const bool kda = (src_g->ne[0] == S_v);
+
+    const int64_t rq1 = nev1 / neq1;
+    const int64_t rq3 = nev3 / neq3;
+
+    const float * q_d = (const float *) src_q->data;
+    const float * k_d = (const float *) src_k->data;
+    const float * v_d = (const float *) src_v->data;
+    const float * g_d = (const float *) src_g->data;
+    const float * b_d = (const float *) src_beta->data;
+
+    const float * s_d   = (const float *) src_state->data;
+    float *       dst_d = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(src_g->ne[0] == 1 || kda);
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    // strides in floats (beta strides used for both g and beta offset computation)
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sb1 = nbb1 / sizeof(float);
+    const int64_t sb2 = nbb2 / sizeof(float);
+    const int64_t sb3 = nbb3 / sizeof(float);
+
+    const float scale = 1.0f / sqrtf((float) S_v);
+
+    cudaStream_t stream = ctx.stream();
+
+    if (kda) {
+        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, rq1, rq3, scale, stream);
+    } else {
+        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, rq1, rq3, scale, stream);
+    }
+}
diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh
new file mode 100644 (file)
index 0000000..7375e81
--- /dev/null
@@ -0,0 +1,4 @@
+#include "common.cuh"
+#include "ggml.h"
+
+void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 54dc43bc088ddcabda30e797efe5107cabddba01..a8007a06360be1d915a5b87c942534d98ded4a96 100644 (file)
@@ -53,6 +53,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/gated_delta_net.cuh"
 #include "ggml-cuda/set.cuh"
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
@@ -2733,6 +2734,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_GATED_DELTA_NET:
+            ggml_cuda_op_gated_delta_net(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV7:
             ggml_cuda_op_rwkv_wkv7(ctx, dst);
             break;
@@ -4972,6 +4976,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_GATED_DELTA_NET:
         case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
index d644cca8a6e4f0c1b335678f4446ca439d090374..aeafc395d7100977889568753db701deab7732cd 100644 (file)
@@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GATED_LINEAR_ATTN",
     "RWKV_WKV7",
     "SOLVE_TRI",
+    "GATED_DELTA_NET",
 
     "UNARY",
 
@@ -1048,7 +1049,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1140,6 +1141,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "gated_linear_attn(k, v, q, gate, s)",
     "rwkv_wkv7(r, w, k, v, a, b, s)",
     "A X = B, A triangular, solve X",
+    "gated_delta_net(q, k, v, g, beta, s)",
 
     "unary(x)",
 
@@ -1157,7 +1159,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -6124,6 +6126,57 @@ struct ggml_tensor * ggml_solve_tri(
     return result;
 }
 
+// ggml_gated_delta_net
+
+struct ggml_tensor * ggml_gated_delta_net(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * beta,
+        struct ggml_tensor  * state) {
+    GGML_ASSERT(ggml_is_contiguous_rows(q));
+    GGML_ASSERT(ggml_is_contiguous_rows(k));
+    GGML_ASSERT(ggml_is_contiguous_rows(v));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(beta));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    GGML_ASSERT(q->type == GGML_TYPE_F32);
+    GGML_ASSERT(k->type == GGML_TYPE_F32);
+    GGML_ASSERT(v->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(beta->type == GGML_TYPE_F32);
+    GGML_ASSERT(state->type == GGML_TYPE_F32);
+
+    const int64_t S_v      = v->ne[0];
+    const int64_t H        = v->ne[1];
+    const int64_t n_tokens = v->ne[2];
+    const int64_t n_seqs   = v->ne[3];
+
+    // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA)
+    GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
+    GGML_ASSERT(beta->ne[0] == 1);
+
+    GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs);
+
+    // concat output and new_state into a single tensor
+    // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs
+    const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op     = GGML_OP_GATED_DELTA_NET;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = g;
+    result->src[4] = beta;
+    result->src[5] = state;
+
+    return result;
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 struct ggml_hash_set ggml_hash_set_new(size_t size) {