add op gated_delta_net (llama/20455)
author     Neo Zhang <redacted>
           Sat, 14 Mar 2026 14:01:57 +0000 (22:01 +0800)
committer  Georgi Gerganov <redacted>
           Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
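
Implements GGML_OP_GATED_DELTA_NET for the SYCL backend. For each head and
sequence the kernel advances an S_v x S_v recurrent state matrix S across the
tokens; in the notation of the in-kernel comments (gates are stored in log
space and exponentiated on the fly):

    delta_t = beta_t * (v_t - g_t * S^T k_t)
    S       = g_t * S + k_t * delta_t^T
    out_t   = scale * S^T q_t,  with scale = 1/sqrt(S_v)

The KDA variant replaces the scalar per-head gate g_t with a per-channel gate
vector g_t[i], applied row-wise to S.
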
src/ggml-sycl/common.hpp
src/ggml-sycl/gated_delta_net.cpp [new file with mode: 0644]
src/ggml-sycl/gated_delta_net.hpp [new file with mode: 0644]
src/ggml-sycl/ggml-sycl.cpp

diff --git a/src/ggml-sycl/common.hpp b/src/ggml-sycl/common.hpp
index 9f0efb6535935eb2386744ab42094960ed979b17..fcb0db99c6b28990b2f02949b52e91c2e5bddadd 100644
@@ -211,7 +211,7 @@ struct sycl_device_info {
              // number of compute units on a SYCL device.
     // size_t  smpb;               // max. shared memory per block
     size_t  smpbo;              // max. shared memory per block (with opt-in)
-    int warp_size;     // max sub_group_size of SYCL
+    int warp_size;     // sub-group size: WARP_SIZE (16), WARP_32_SIZE (32), or WARP_16_SIZE (16). For Intel GPUs, 16 is better in most cases; some ops support only 32.
     int max_wg_per_cu; // max work groups per compute unit - refer to
                        // cudaOccupancyMaxActiveBlocksPerMultiprocessor
     bool    vmm;                // virtual memory support
diff --git a/src/ggml-sycl/gated_delta_net.cpp b/src/ggml-sycl/gated_delta_net.cpp
new file mode 100644
index 0000000..8c76afb
--- /dev/null
@@ -0,0 +1,309 @@
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+#include "gated_delta_net.hpp"
+#include <cmath>
+
+
+template <int S_v, bool KDA>
+void gated_delta_net_sycl(const float *     q,
+                          const float *     k,
+                          const float *     v,
+                          const float *     g,
+                          const float *     beta,
+                          const float *     curr_state,
+                          float *           dst,
+                          int64_t           H,
+                          int64_t           n_tokens,
+                          int64_t           n_seqs,
+                          int64_t           sq1,
+                          int64_t           sq2,
+                          int64_t           sq3,
+                          int64_t           sv1,
+                          int64_t           sv2,
+                          int64_t           sv3,
+                          int64_t           sb1,
+                          int64_t           sb2,
+                          int64_t           sb3,
+                          const sycl::uint3 neqk1_magic,
+                          const sycl::uint3 rq3_magic,
+                          float             scale) {
+    auto           item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const uint32_t h_idx    = item_ct1.get_group(2);
+    const uint32_t sequence = item_ct1.get_group(1);
+    // each warp owns one column, using warp-level primitives to reduce across rows
+    const int      lane     = item_ct1.get_local_id(2);
+    const int      col      = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
+
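+    // q/k may have fewer heads/sequences than v (broadcast): map this head to
+    // its q/k head (iq1) and this sequence to its q/k sequence (iq3), using
+    // precomputed magic numbers to avoid slow integer div/mod on device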
+    const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+    const uint32_t iq3 = fastdiv(sequence, rq3_magic);
+
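+    // dst packs the per-token attention scores first, then the updated state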
+    const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+    float *       attn_data        = dst;
+    float *       state            = dst + attn_score_elems;
+
+    const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+    state += state_offset;
+    curr_state += state_offset;
+    attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+
+    constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
+    static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+    constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
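+    // each lane caches rows_per_lane rows of its state column in registers for
+    // the whole token loop: read the global state once here, write it back
+    // once at the end of the kernel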
+    float         s_shard[rows_per_lane];
+#pragma unroll
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i = r * warp_size + lane;
+        s_shard[r]  = curr_state[i * S_v + col];
+    }
+
+    for (int t = 0; t < n_tokens; t++) {
+        const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+        const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
+
+        const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
+        const float * beta_t = beta + gb_offset;
+        const float * g_t    = g    + gb_offset * (KDA ? S_v : 1);
+
+        const float beta_val = *beta_t;
+
+        if constexpr (!KDA) {
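+            // scalar per-head gate, stored in log space: exponentiate once per token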
+            const float g_val = sycl::native::exp(*g_t);
+
+            // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
+            float kv_shard = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += s_shard[r] * k_t[i];
+            }
+            float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
+            // delta[col] = (v[col] - g * kv[col]) * beta
+            float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
+
+            // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_partial = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = g_val * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
+            }
+
+            float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
+        } else {
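+            // KDA: per-channel gate vector g_t[0..S_v) instead of a single
+            // scalar gate per head, applied row-wise to S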
+            // kv[col] = sum_i g[i] * S[i][col] * k[i]
+            float kv_shard = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i];
+            }
+
+            float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
+            // delta[col] = (v[col] - kv[col]) * beta
+            float delta_col = (v_t[col] - kv_col) * beta_val;
+
+            // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
+            // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+            float attn_partial = 0.0f;
+#pragma unroll
+            for (int r = 0; r < rows_per_lane; r++) {
+                const int i = r * warp_size + lane;
+                s_shard[r]  = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+                attn_partial += s_shard[r] * q_t[i];
+            }
+
+            float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+            if (lane == 0) {
+                attn_data[col] = attn_col * scale;
+            }
+        }
+
+        attn_data += S_v * H;
+    }
+
+    // Write state back to global memory
+#pragma unroll
+    for (int r = 0; r < rows_per_lane; r++) {
+        const int i          = r * warp_size + lane;
+        state[i * S_v + col] = s_shard[r];
+    }
+}
+
+template <bool KDA>
+static void launch_gated_delta_net(const float *   q_d,
+                                   const float *   k_d,
+                                   const float *   v_d,
+                                   const float *   g_d,
+                                   const float *   b_d,
+                                   const float *   s_d,
+                                   float *         dst_d,
+                                   int64_t         S_v,
+                                   int64_t         H,
+                                   int64_t         n_tokens,
+                                   int64_t         n_seqs,
+                                   int64_t         sq1,
+                                   int64_t         sq2,
+                                   int64_t         sq3,
+                                   int64_t         sv1,
+                                   int64_t         sv2,
+                                   int64_t         sv3,
+                                   int64_t         sb1,
+                                   int64_t         sb2,
+                                   int64_t         sb3,
+                                   int64_t         neqk1,
+                                   int64_t         rq3,
+                                   float           scale,
+                                   dpct::queue_ptr stream) {
+    // TODO: add a chunked kernel for even faster prefill
+    const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
+
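+    // one work-group per (head, sequence, group of columns); each sub-group
+    // ("warp") owns one column of the state matrix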
+    const int num_warps = 4;
+    dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+    dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
+
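+    // magic-number constants let the kernel replace integer div/mod with
+    // multiply-and-shift (fastdiv/fastmodulo)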
+    const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
+    const sycl::uint3 rq3_magic   = init_fastdiv_values(rq3);
+
+    int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
+
+    switch (S_v) {
+        case 16:
+            {
+                constexpr int sv = 16;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        case 32:
+            {
+                constexpr int sv = 32;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        case 64:
+            {
+                constexpr int sv = 64;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        case 128:
+            {
+                constexpr int sv = 128;
+                stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                                     [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                         gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+                                                                       n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+                                                                       sb3, neqk1_magic, rq3_magic, scale);
+                                     });
+            }
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src_q     = dst->src[0];
+    ggml_tensor * src_k     = dst->src[1];
+    ggml_tensor * src_v     = dst->src[2];
+    ggml_tensor * src_g     = dst->src[3];
+    ggml_tensor * src_beta  = dst->src[4];
+    ggml_tensor * src_state = dst->src[5];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
+    GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+    GGML_TENSOR_LOCALS(size_t,  nbv, src_v, nb);
+    GGML_TENSOR_LOCALS(size_t,  nbb, src_beta, nb);
+
+    const int64_t S_v      = nev0;
+    const int64_t H        = nev1;
+    const int64_t n_tokens = nev2;
+    const int64_t n_seqs   = nev3;
+
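+    // per-channel (KDA) gating when g carries S_v values per head; otherwise a
+    // single scalar gate per head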
+    const bool kda = (src_g->ne[0] == S_v);
+
+    GGML_ASSERT(neq1 == nek1);
+    const int64_t neqk1 = neq1;
+
+    const int64_t rq3 = nev3 / neq3;
+
+    const float * q_d = (const float *) src_q->data;
+    const float * k_d = (const float *) src_k->data;
+    const float * v_d = (const float *) src_v->data;
+    const float * g_d = (const float *) src_g->data;
+    const float * b_d = (const float *) src_beta->data;
+
+    const float * s_d   = (const float *) src_state->data;
+    float *       dst_d = (float *) dst->data;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+    GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+    GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+    GGML_ASSERT(src_g->ne[0] == 1 || kda);
+    GGML_ASSERT(ggml_is_contiguous(src_g));
+    GGML_ASSERT(ggml_is_contiguous(src_beta));
+    GGML_ASSERT(ggml_is_contiguous(src_state));
+
+    // strides in floats (beta strides used for both g and beta offset computation)
+    const int64_t sq1 = nbq1 / sizeof(float);
+    const int64_t sq2 = nbq2 / sizeof(float);
+    const int64_t sq3 = nbq3 / sizeof(float);
+    const int64_t sv1 = nbv1 / sizeof(float);
+    const int64_t sv2 = nbv2 / sizeof(float);
+    const int64_t sv3 = nbv3 / sizeof(float);
+    const int64_t sb1 = nbb1 / sizeof(float);
+    const int64_t sb2 = nbb2 / sizeof(float);
+    const int64_t sb3 = nbb3 / sizeof(float);
+
+    const float scale = 1.0f / sqrtf((float) S_v);
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    if (kda) {
+        launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
+    } else {
+        launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+            S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
+    }
+}
+
+void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);
+    ggml_sycl_op_gated_delta_net(ctx, dst);
+}
diff --git a/src/ggml-sycl/gated_delta_net.hpp b/src/ggml-sycl/gated_delta_net.hpp
new file mode 100644
index 0000000..a3308ee
--- /dev/null
@@ -0,0 +1,8 @@
+#pragma once
+
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+
+void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
diff --git a/src/ggml-sycl/ggml-sycl.cpp b/src/ggml-sycl/ggml-sycl.cpp
index f887061b279b5c0f6aecf092cb8681e5617e221a..128197058490d1e72d89045988b2a37bb6162b55 100644
@@ -35,6 +35,7 @@
 #endif
 #include <sycl/half_type.hpp>
 
+#include "ggml.h"
 #include "ggml-sycl.h"
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/common.hpp"
 #include "ggml-sycl/element_wise.hpp"
+#include "ggml-sycl/gated_delta_net.hpp"
+#include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/getrows.hpp"
 #include "ggml-sycl/norm.hpp"
 #include "ggml-sycl/presets.hpp"
-#include "ggml-sycl/gemm.hpp"
+#include "ggml-sycl/quantize.hpp"
+#include "ggml-sycl/repeat_back.hpp"
 #include "ggml-sycl/set_rows.hpp"
 #include "ggml-sycl/set.hpp"
-#include "ggml-sycl/sycl_hw.hpp"
-#include "ggml-sycl/getrows.hpp"
-#include "ggml-sycl/repeat_back.hpp"
-#include "ggml-sycl/quantize.hpp"
 #include "ggml-sycl/ssm_conv.hpp"
-#include "ggml.h"
+#include "ggml-sycl/sycl_hw.hpp"
+
 
 static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
@@ -99,6 +101,8 @@ static ggml_sycl_device_info ggml_sycl_init() {
         info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
         info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.devices[i].smpbo = prop.get_local_mem_size();
+        info.devices[i].warp_size = WARP_SIZE;
+
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
         info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
 
@@ -4181,6 +4185,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_GATED_DELTA_NET:
+            ggml_sycl_gated_delta_net(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_sycl_ssm_conv(ctx, dst);
             break;
@@ -4890,6 +4897,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_GATED_DELTA_NET:
             return true;
         case GGML_OP_SSM_CONV:
             return op->type == GGML_TYPE_F32 &&