"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=193,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=1056,n=1,k=67,bs=[1,1],nr=[4,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=64,n=77,k=77,bs=[12,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
-"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
-"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=3,k=4,bs=[128,1024],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
-"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,2,1,3],k_v=0,o=1","support","1","yes","SYCL"
-"SYCL0","MUL_MAT","type_a=f16,type_b=f32,m=2,n=1,k=3,bs=[131072,1],nr=[1,1],per=[0,1,2,3],k_v=64,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=576,n=512,k=576,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","MUL_MAT","type_a=f32,type_b=f32,m=1,n=64,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=-1","support","1","yes","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=-1","support","1","yes","SYCL"
-"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","1","yes","SYCL"
-"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","1","yes","SYCL"
+"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[256,16,2,3],stride_dim=1","support","0","no","SYCL"
+"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[128,16,2,3],stride_dim=2","support","0","no","SYCL"
"SYCL0","ACC","type=f32,ne_a=[256,17,2,3],ne_b=[64,16,2,3],stride_dim=3","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","SYCL"
"SYCL0","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","SYCL"
"SYCL0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","SYCL"
"SYCL0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
"SYCL0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","SYCL"
-"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=32,head_size=16,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
+"SYCL0","GATED_DELTA_NET","type=f32,head_count=4,head_size=16,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","1","yes","SYCL"
--- /dev/null
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+#include "gated_delta_net.hpp"
+#include <cmath>
+
+
+template <int S_v, bool KDA>
+void gated_delta_net_sycl(const float * q,
+ const float * k,
+ const float * v,
+ const float * g,
+ const float * beta,
+ const float * curr_state,
+ float * dst,
+ int64_t H,
+ int64_t n_tokens,
+ int64_t n_seqs,
+ int64_t sq1,
+ int64_t sq2,
+ int64_t sq3,
+ int64_t sv1,
+ int64_t sv2,
+ int64_t sv3,
+ int64_t sb1,
+ int64_t sb2,
+ int64_t sb3,
+ const sycl::uint3 neqk1_magic,
+ const sycl::uint3 rq3_magic,
+ float scale) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const uint32_t h_idx = item_ct1.get_group(2);
+ const uint32_t sequence = item_ct1.get_group(1);
+ // each warp owns one column, using warp-level primitives to reduce across rows
+ const int lane = item_ct1.get_local_id(2);
+ const int col = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1);
+
+ const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
+ const uint32_t iq3 = fastdiv(sequence, rq3_magic);
+
+ const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
+ float * attn_data = dst;
+ float * state = dst + attn_score_elems;
+
+ const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v;
+ state += state_offset;
+ curr_state += state_offset;
+ attn_data += (sequence * n_tokens * H + h_idx) * S_v;
+
+ constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v;
+ static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
+ constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
+ float s_shard[rows_per_lane];
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ s_shard[r] = curr_state[i * S_v + col];
+ }
+
+ for (int t = 0; t < n_tokens; t++) {
+ const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1;
+ const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1;
+ const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1;
+
+ const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1;
+ const float * beta_t = beta + gb_offset;
+ const float * g_t = g + gb_offset * (KDA ? S_v : 1);
+
+ const float beta_val = *beta_t;
+
+ if constexpr (!KDA) {
+ const float g_val = sycl::native::exp(*g_t);
+
+ // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i]
+ float kv_shard = 0.0f;
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ kv_shard += s_shard[r] * k_t[i];
+ }
+ float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
+ // delta[col] = (v[col] - g * kv[col]) * beta
+ float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
+
+ // fused: S[i][col] = g * S[i][col] + k[i] * delta[col]
+ // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+ float attn_partial = 0.0f;
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
+ attn_partial += s_shard[r] * q_t[i];
+ }
+
+ float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+ if (lane == 0) {
+ attn_data[col] = attn_col * scale;
+ }
+ } else {
+ // kv[col] = sum_i g[i] * S[i][col] * k[i]
+ float kv_shard = 0.0f;
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i];
+ }
+
+ float kv_col = warp_reduce_sum<warp_size>(kv_shard);
+
+ // delta[col] = (v[col] - kv[col]) * beta
+ float delta_col = (v_t[col] - kv_col) * beta_val;
+
+ // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col]
+ // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i]
+ float attn_partial = 0.0f;
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ s_shard[r] = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
+ attn_partial += s_shard[r] * q_t[i];
+ }
+
+ float attn_col = warp_reduce_sum<warp_size>(attn_partial);
+
+ if (lane == 0) {
+ attn_data[col] = attn_col * scale;
+ }
+ }
+
+ attn_data += S_v * H;
+ }
+
+ // Write state back to global memory
+#pragma unroll
+ for (int r = 0; r < rows_per_lane; r++) {
+ const int i = r * warp_size + lane;
+ state[i * S_v + col] = s_shard[r];
+ }
+}
+
+template <bool KDA>
+static void launch_gated_delta_net(const float * q_d,
+ const float * k_d,
+ const float * v_d,
+ const float * g_d,
+ const float * b_d,
+ const float * s_d,
+ float * dst_d,
+ int64_t S_v,
+ int64_t H,
+ int64_t n_tokens,
+ int64_t n_seqs,
+ int64_t sq1,
+ int64_t sq2,
+ int64_t sq3,
+ int64_t sv1,
+ int64_t sv2,
+ int64_t sv3,
+ int64_t sb1,
+ int64_t sb2,
+ int64_t sb3,
+ int64_t neqk1,
+ int64_t rq3,
+ float scale,
+ dpct::queue_ptr stream) {
+ //TODO: Add chunked kernel for even faster pre-fill
+ const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size;
+
+ const int num_warps = 4;
+ dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
+ dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
+
+ const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1);
+ const sycl::uint3 rq3_magic = init_fastdiv_values(rq3);
+
+ int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
+
+ switch (S_v) {
+ case 16:
+ {
+ constexpr int sv = 16;
+ stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+ n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+ sb3, neqk1_magic, rq3_magic, scale);
+ });
+ }
+ break;
+ case 32:
+ {
+ constexpr int sv = 32;
+ stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ gated_delta_net_sycl<sv, KDA>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens,
+ n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2,
+ sb3, neqk1_magic, rq3_magic, scale);
+ });
+ }
+ break;
+ case 64: {
+ {
+ constexpr int sv = 64;
+ stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ gated_delta_net_sycl<sv, KDA>(
+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
+ sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ });
+ }
+ break;
+ }
+ case 128: {
+ {
+ constexpr int sv = 128;
+ stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ gated_delta_net_sycl<sv, KDA>(
+ q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2,
+ sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
+ });
+ }
+ break;
+ }
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * src_q = dst->src[0];
+ ggml_tensor * src_k = dst->src[1];
+ ggml_tensor * src_v = dst->src[2];
+ ggml_tensor * src_g = dst->src[3];
+ ggml_tensor * src_beta = dst->src[4];
+ ggml_tensor * src_state = dst->src[5];
+
+ GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
+ GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+ GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+ GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
+ GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
+ GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
+ GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
+
+ const int64_t S_v = nev0;
+ const int64_t H = nev1;
+ const int64_t n_tokens = nev2;
+ const int64_t n_seqs = nev3;
+
+ const bool kda = (src_g->ne[0] == S_v);
+
+ GGML_ASSERT(neq1 == nek1);
+ const int64_t neqk1 = neq1;
+
+ const int64_t rq3 = nev3 / neq3;
+
+ const float * q_d = (const float *) src_q->data;
+ const float * k_d = (const float *) src_k->data;
+ const float * v_d = (const float *) src_v->data;
+ const float * g_d = (const float *) src_g->data;
+ const float * b_d = (const float *) src_beta->data;
+
+ const float * s_d = (const float *) src_state->data;
+ float * dst_d = (float *) dst->data;
+
+ GGML_ASSERT(ggml_is_contiguous_rows(src_q));
+ GGML_ASSERT(ggml_is_contiguous_rows(src_k));
+ GGML_ASSERT(ggml_is_contiguous_rows(src_v));
+ GGML_ASSERT(ggml_are_same_stride(src_q, src_k));
+ GGML_ASSERT(src_g->ne[0] == 1 || kda);
+ GGML_ASSERT(ggml_is_contiguous(src_g));
+ GGML_ASSERT(ggml_is_contiguous(src_beta));
+ GGML_ASSERT(ggml_is_contiguous(src_state));
+
+ // strides in floats (beta strides used for both g and beta offset computation)
+ const int64_t sq1 = nbq1 / sizeof(float);
+ const int64_t sq2 = nbq2 / sizeof(float);
+ const int64_t sq3 = nbq3 / sizeof(float);
+ const int64_t sv1 = nbv1 / sizeof(float);
+ const int64_t sv2 = nbv2 / sizeof(float);
+ const int64_t sv3 = nbv3 / sizeof(float);
+ const int64_t sb1 = nbb1 / sizeof(float);
+ const int64_t sb2 = nbb2 / sizeof(float);
+ const int64_t sb3 = nbb3 / sizeof(float);
+
+ const float scale = 1.0f / sqrtf((float) S_v);
+
+ dpct::queue_ptr stream = ctx.stream();
+
+ if (kda) {
+ launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ } else {
+ launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
+ S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
+ sb1, sb2, sb3, neqk1, rq3, scale, stream);
+ }
+}
+
+void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);
+ ggml_sycl_op_gated_delta_net(ctx, dst);
+}