]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: add GATED_DELTA_NET op support (llama/20334)
authorProgenyAlpha <redacted>
Thu, 12 Mar 2026 10:32:04 +0000 (06:32 -0400)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
* vulkan: add GATED_DELTA_NET op support

Implements the fused gated delta net recurrence as a Vulkan compute
shader with full support for scalar gate, KDA vector gate, GQA
broadcast, multi-token sequences, and permuted (non-contiguous) q/k
inputs. Specialization constants select head size (32/64/128) and
KDA mode at pipeline creation time.

Passes all 13 test-backend-ops cases on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6 <redacted>
* vulkan: optimize GATED_DELTA_NET shader (Phase 1)

- vec4 dot products on all inner loops (dp4 hardware intrinsic)
- Cache exp(g) in shared memory for KDA path, eliminating ~32K
  redundant global reads and ~16K redundant exp() calls per token
- vec4 fused decay + rank-1 update (3 vec4 ops vs 12 scalar ops)
- Add perf benchmark cases for GATED_DELTA_NET to test-backend-ops

KDA TG: +5.4% throughput. Non-KDA: no regressions.
13/13 test-backend-ops passing on AMD Radeon 890M (RADV GFX1150).

Co-Authored-By: Claude Opus 4.6 <redacted>
* vulkan: address review feedback for GATED_DELTA_NET

Pipeline array refactor [3][2], A_TYPE/D_TYPE/FLOAT_TYPE shader macros,
scale in push constants, supports_op fix, dispatch restructuring.

Co-Authored-By: Claude Opus 4.6 <redacted>
* vulkan: use FLOAT_TYPE for buffer/shared declarations, align formatting

Co-Authored-By: Claude Opus 4.6 <redacted>
* vulkan: add explicit FLOAT_TYPE casts for buffer loads

Wrap data_q, data_k, and data_g buffer reads with FLOAT_TYPE() casts
to ensure correct behavior across all Vulkan configurations.

Co-Authored-By: Claude Opus 4.6 <redacted>
* vulkan: fix Q/K broadcast for interleaved head layout

Adapt to the interleaved broadcast convention from #20340:
head_id / rq1 → head_id % neq1

Co-Authored-By: Claude Opus 4.6 <redacted>
---------

Co-authored-by: Progeny Alpha <redacted>
Co-authored-by: Claude Opus 4.6 <redacted>
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 2a2f7f4f11c86baf517e68ff6971c3b9452add25..3c81805b844ce6194c5b320d96959a3fcdfff39a 100644 (file)
@@ -825,6 +825,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
+    // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
+    vk_pipeline pipeline_gated_delta_net[3][2];
     vk_pipeline pipeline_ssm_scan_f32_d128;
     vk_pipeline pipeline_ssm_scan_f32_d256;
     vk_pipeline pipeline_ssm_conv_f32;
@@ -1454,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t C;
     uint32_t H;
 };
+struct vk_op_gated_delta_net_push_constants {
+    uint32_t H;
+    uint32_t n_tokens;
+    uint32_t n_seqs;
+    uint32_t s_off;
+    uint32_t sq1, sq2, sq3;
+    uint32_t sv1, sv2, sv3;
+    uint32_t sb1, sb2, sb3;
+    uint32_t neq1, rq3;
+    float scale;
+};
+
 struct vk_op_ssm_scan_push_constants {
     uint32_t nb02, nb03, nb12, nb13;
     uint32_t nb21, nb22, nb31;
@@ -4568,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    {
+        const uint32_t gdn_sizes[] = {32, 64, 128};
+        const char * gdn_names[][2] = {
+            {"gated_delta_net_f32_d32",     "gated_delta_net_f32_d32_kda"},
+            {"gated_delta_net_f32_d64",     "gated_delta_net_f32_d64_kda"},
+            {"gated_delta_net_f32_d128",    "gated_delta_net_f32_d128_kda"},
+        };
+        for (uint32_t si = 0; si < 3; si++) {
+            for (uint32_t kda = 0; kda < 2; kda++) {
+                ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
+                    gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
+                    "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
+                    {1, 1, 1}, {gdn_sizes[si], kda}, 1);
+            }
+        }
+    }
+
     if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
         ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -9498,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv7_f32;
         }
         return nullptr;
+    case GGML_OP_GATED_DELTA_NET:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            const uint32_t S_v = dst->src[2]->ne[0];
+            const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
+            uint32_t si;
+            switch (S_v) {
+                case 32:  si = 0; break;
+                case 64:  si = 1; break;
+                case 128: si = 2; break;
+                default: return nullptr;
+            }
+            return ctx->device->pipeline_gated_delta_net[si][kda];
+        }
+        return nullptr;
     case GGML_OP_SSM_SCAN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             const uint32_t d_state = src0->ne[0];
@@ -10328,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
+    const ggml_tensor * src_q     = dst->src[0];
+    const ggml_tensor * src_v     = dst->src[2];
+    const ggml_tensor * src_beta  = dst->src[4];
+
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    const uint32_t S_v      = (uint32_t)src_v->ne[0];
+    const uint32_t H        = (uint32_t)src_v->ne[1];
+    const uint32_t n_tokens = (uint32_t)src_v->ne[2];
+    const uint32_t n_seqs   = (uint32_t)src_v->ne[3];
+
+    const uint32_t s_off = S_v * H * n_tokens * n_seqs;
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+    GGML_ASSERT(pipeline != nullptr);
+
+    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+    vk_subbuffer src_buf[6] = {};
+    for (int i = 0; i < 6; i++) {
+        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
+    }
+
+    const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
+    const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
+    const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
+    const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
+    const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
+    const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
+    const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
+    const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
+    const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
+
+    const uint32_t neq1 = (uint32_t)src_q->ne[1];
+    const uint32_t rq3  = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
+
+    const float scale = 1.0f / sqrtf((float)S_v);
+    const vk_op_gated_delta_net_push_constants pc = {
+        H, n_tokens, n_seqs, s_off,
+        sq1, sq2, sq3,
+        sv1, sv2, sv3,
+        sb1, sb2, sb3,
+        neq1, rq3,
+        scale
+    };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+        pc, { H, n_seqs, 1u });
+}
+
 static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -13044,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
 
+    case GGML_OP_GATED_DELTA_NET:
+        ggml_vk_gated_delta_net(ctx, compute_ctx, node);
+
+        break;
+
     case GGML_OP_SSM_SCAN:
         ggml_vk_ssm_scan(ctx, compute_ctx, node);
 
@@ -15455,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true; // all inputs are contiguous, see ggml.c
+        case GGML_OP_GATED_DELTA_NET:
+            {
+                const uint32_t S_v = op->src[2]->ne[0];
+                if (S_v != 32 && S_v != 64 && S_v != 128) {
+                    return false;
+                }
+                for (int i = 0; i < 6; i++) {
+                    if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
+                        return false;
+                    }
+                }
+                return op->type == GGML_TYPE_F32;
+            }
         case GGML_OP_SSM_SCAN:
             {
                 for (int i = 0; i < 6; i++) {
@@ -16332,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else if (tensor->op == GGML_OP_RWKV_WKV7) {
             tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
             src_clone[4], src_clone[5], src_clone[6]);
+        } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
+            tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
+            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
         } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
             src_clone[0]->flags = tensor->src[0]->flags;
             tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
diff --git a/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
new file mode 100644 (file)
index 0000000..1fdf889
--- /dev/null
@@ -0,0 +1,128 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(constant_id = 0) const uint S_V = 128;
+layout(constant_id = 1) const uint KDA = 0;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout(push_constant) uniform Parameters {
+    uint H;
+    uint n_tokens;
+    uint n_seqs;
+    uint s_off;
+    uint sq1, sq2, sq3;
+    uint sv1, sv2, sv3;
+    uint sb1, sb2, sb3;
+    uint neq1, rq3;
+    float scale;
+};
+
+layout(binding = 0) readonly  buffer QBuf     { FLOAT_TYPE data_q[];     };
+layout(binding = 1) readonly  buffer KBuf     { FLOAT_TYPE data_k[];     };
+layout(binding = 2) readonly  buffer VBuf     { FLOAT_TYPE data_v[];     };
+layout(binding = 3) readonly  buffer GBuf     { FLOAT_TYPE data_g[];     };
+layout(binding = 4) readonly  buffer BetaBuf  { FLOAT_TYPE data_beta[];  };
+layout(binding = 5) readonly  buffer StateBuf { FLOAT_TYPE data_state[]; };
+layout(binding = 6)           buffer DstBuf   { FLOAT_TYPE data_dst[];   };
+
+shared FLOAT_TYPE s_k[S_V];
+shared FLOAT_TYPE s_q[S_V];
+shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i])
+
+void main() {
+    const uint head_id = gl_WorkGroupID.x;
+    const uint seq_id  = gl_WorkGroupID.y;
+    const uint col     = gl_LocalInvocationID.x;
+
+    const uint iq1 = head_id % neq1;
+    const uint iq3 = seq_id / rq3;
+
+    const uint state_size = S_V * S_V;
+    const uint state_base = (seq_id * H + head_id) * state_size;
+
+    FLOAT_TYPE state[S_V];
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
+    }
+
+    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
+
+    for (uint t = 0; t < n_tokens; t++) {
+        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const uint k_off = q_off;
+        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;
+
+        s_q[col] = FLOAT_TYPE(data_q[q_off + col]);
+        s_k[col] = FLOAT_TYPE(data_k[k_off + col]);
+
+        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
+
+        if (KDA != 0) {
+            const uint g_base = gb_off * S_V;
+            s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col]));
+        }
+
+        barrier();
+
+        const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]);
+        const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]);
+
+        if (KDA == 0) {
+            const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off]));
+
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                kv_col += dot(
+                    vec4(state[i], state[i+1], state[i+2], state[i+3]),
+                    vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3])
+                );
+            }
+
+            FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = g_val * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        } else {
+            FLOAT_TYPE kv_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                kv_col += dot(gv * sv, kv);
+            }
+
+            FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val;
+
+            FLOAT_TYPE attn_col = 0.0;
+            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
+                vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]);
+                vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]);
+                vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]);
+                sv = gv * sv + kv * delta_col;
+                state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w;
+
+                attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3]));
+            }
+
+            data_dst[attn_off + col] = attn_col * scale;
+        }
+
+        attn_off += S_V * H;
+        barrier();
+    }
+
+    [[unroll]] for (uint i = 0; i < S_V; i++) {
+        data_dst[s_off + state_base + i * S_V + col] = state[i];
+    }
+}
index fb8941232bc7fa211116096e5936f48282d3ca70..4b00ba3debb8091eb0c83c67c41d63d4bf031d5b 100644 (file)
@@ -987,6 +987,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}}));
+
     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
index a821655d10c7b28bd70d042ffd2b801d171a5dfc..e9f2e8ace469658521a384d52fe1e835b3f33303 100644 (file)
@@ -8731,6 +8731,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
     test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
 
+    // GATED_DELTA_NET: realistic model configurations
+    // TG: n_seq_tokens=1 (autoregressive)
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));   // Qwen3.5-like: 32 heads, d=128
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64,  1, 1));   // smaller model
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1, 1, false, true)); // KDA
+    // PP: n_seq_tokens=64,256 (prompt processing)
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1));  // PP-64
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 256, 1)); // PP-256
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 512, 1)); // PP-512
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1024, 1)); // PP-1024
+    // Small model configs (fewer heads = less GPU occupancy for autoregressive)
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 64, 1));   // 4h PP-64
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 256, 1));  // 4h PP-256
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 512, 1));  // 4h PP-512
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024
+    test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64
+
     return test_cases;
 }