]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Optimize RWKV6 Operator Naming and Implement Multi-core CPU/ SYCL Acceleration (llama...
authorZhiyuan Li <redacted>
Thu, 7 Nov 2024 07:19:10 +0000 (18:19 +1100)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
* rwkv6: rename to wkv6

* rwkv6: support avx2 avx512 armv8 armv9

* rwkv6: update cuda file name

* rwkv6: rename params

* wkv on sycl

* sycl: add some ops

* sycl: Enhance OP support judgment

* wkv6: drop armv9 and tranfer to GGML style

ggml-ci

* sync : ggml

* update the function to use appropriate types

* fix define error

* Update ggml/src/ggml-cpu.c

* add appropriate asserts

* move element-wise functions outside

* put the declaration outside the loop

* rewrite to be more inline with the common pattern for distributing threads

* use recommended way GGML_TENSOR_LOCALS

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Diego Devesa <redacted>
Co-authored-by: Plamen Minev <redacted>
Co-authored-by: Yuri Khrustalev <redacted>
Co-authored-by: Meng, Hengyu <redacted>
18 files changed:
ggml/include/ggml.h
ggml/src/ggml-cpu.c
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/wkv6.cu [new file with mode: 0644]
ggml/src/ggml-cuda/wkv6.cuh [new file with mode: 0644]
ggml/src/ggml-sycl.cpp
ggml/src/ggml-sycl/backend.hpp
ggml/src/ggml-sycl/common.cpp
ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/concat.cpp
ggml/src/ggml-sycl/element_wise.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/element_wise.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/outprod.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/outprod.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/presets.hpp
ggml/src/ggml-sycl/wkv6.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/wkv6.hpp [new file with mode: 0644]
ggml/src/ggml.c

index 8a0bcbff8c61af27d0d2057dd06e58b2a1acbd9d..0d143d2fe0a20103e8feb99835bf3860fdd09a1c 100644 (file)
@@ -509,7 +509,7 @@ extern "C" {
         GGML_OP_WIN_UNPART,
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
-        GGML_OP_RWKV_WKV,
+        GGML_OP_RWKV_WKV6,
 
         GGML_OP_UNARY,
 
@@ -1819,7 +1819,7 @@ extern "C" {
             struct ggml_tensor  * pw,
             struct ggml_tensor  * ph);
 
-    GGML_API struct ggml_tensor * ggml_rwkv_wkv(
+    GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
             struct ggml_context * ctx,
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
index 0cb5b824afc05a900133cebe1a1857491fe8438a..98c3e21ae3fcff3ae2ab8e132a766cc64f6fdf5d 100644 (file)
@@ -11642,24 +11642,30 @@ static void ggml_compute_forward_add_rel_pos(
     }
 }
 
-// ggml_compute_forward_rwkv_wkv
+// ggml_compute_forward_rwkv_wkv6
 
-static void ggml_compute_forward_rwkv_wkv_f32(
+static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const size_t T = dst->src[1]->ne[3];
-    const size_t C = dst->ne[0];
-    const size_t H = dst->src[1]->ne[2];
-    const size_t n_seqs = dst->src[5]->ne[1];
+    const int64_t T = dst->src[1]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t n_seqs = dst->src[5]->ne[1];
+    const int64_t head_size = C / HEADS;
 
     float * dst_data = (float *) dst->data;
     float * state = ((float *) dst->data) + C * T;
 
-    if (params->ith != 0) {
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
         return;
     }
 
-    memset(dst_data, 0, T * C * sizeof(float));
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
 
     float * k =          (float *) dst->src[0]->data;
     float * v =          (float *) dst->src[1]->data;
@@ -11667,54 +11673,160 @@ static void ggml_compute_forward_rwkv_wkv_f32(
     float * time_faaaa = (float *) dst->src[3]->data;
     float * time_decay = (float *) dst->src[4]->data;
 
-    size_t t_stride = H * (C / H);
+    size_t t_stride = HEADS * head_size; // Same to C
 
-    size_t h_stride = C / H;
-    size_t h_stride_2d = (C / H) * (C / H);
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
 
-    // basically fused operations:
-    // dst = r @ (time_faaaa * (k @ v) + state),
-    // state = time_decay * state + (k @ v),
-    // recursive through each token
-    for (size_t t = 0; t < T; t++) {
-        size_t t_offset = t * t_stride;
-        size_t state_offset = (C / H) * C * (t / (T / n_seqs));
-        float * state_cur = state + state_offset;
-        float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
 
-        for (size_t h = 0; h < H; h++) {
-            size_t h_offset = h * h_stride;
-            size_t t_h_offset = t_offset + h_offset;
-            size_t h_2d_offset = h * h_stride_2d;
 
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_i_offset = h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define WKV_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define WKV_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define WKV_VECTOR_SIZE 4
+    #endif
 
-                float k_val = k[t_h_i_offset];
-                float r_val = r[t_h_i_offset];
-                float time_faaaa_val = time_faaaa[h_i_offset];
-                // RWKV v6: different time_decay for each token.
-                float time_decay_val = time_decay[t_h_i_offset];
+    #ifdef WKV_VECTOR_SIZE
+        const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
+                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
+                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * WKV_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = kv * time_faaaa + prev_state
+                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
+
+                        // Update dst: dst += temp * r
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state: state = prev_state * time_decay + kv
+                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
+                    }
 
-                for (size_t j = 0; j < C / H; j ++) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
+                }
+            }
+        }
 
-                    float v_val = v[t_h_j_offset];
-                    float kv_val = v_val * k_val;
-                    float prev_state_val = state_prev[h_2d_i_j_offset];
-                    float temp_val = kv_val * time_faaaa_val + prev_state_val;
-                    dst_data[t_h_j_offset] += temp_val * r_val;
-                    state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+    #else
+        // basically fused operations:
+        // dst = r @ (time_faaaa * (k @ v) + state),
+        // state = time_decay * state + (k @ v),
+        // recursive through each token
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_i_offset = h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float r_val = r[t_h_i_offset];
+                    float time_faaaa_val = time_faaaa[h_i_offset];
+                    // RWKV v6: different time_decay for each token.
+                    float time_decay_val = time_decay[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
+                        dst_data[t_h_j_offset] += temp_val * r_val;
+                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
+                    }
                 }
             }
         }
-    }
+    #endif
 }
 
-static void ggml_compute_forward_rwkv_wkv(
+
+static void ggml_compute_forward_rwkv_wkv6(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -11723,7 +11835,7 @@ static void ggml_compute_forward_rwkv_wkv(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rwkv_wkv_f32(params, dst);
+                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
             } break;
         default:
             {
@@ -12475,9 +12587,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add_rel_pos(params, tensor);
             } break;
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             {
-                ggml_compute_forward_rwkv_wkv(params, tensor);
+                ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
         case GGML_OP_MAP_UNARY:
             {
@@ -12775,7 +12887,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
index e68e405501bedda0eb35c527b66f4103f8d99b14..e27c8e87d5023ab4e45fb4ed4c13d1f1d69e15b1 100644 (file)
@@ -36,7 +36,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/rwkv-wkv.cuh"
+#include "ggml-cuda/wkv6.cuh"
 
 #include <algorithm>
 #include <array>
@@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
-        case GGML_OP_RWKV_WKV:
-            ggml_cuda_op_rwkv_wkv(ctx, dst);
+        case GGML_OP_RWKV_WKV6:
+            ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
@@ -3153,7 +3153,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
diff --git a/ggml/src/ggml-cuda/wkv6.cu b/ggml/src/ggml-cuda/wkv6.cu
new file mode 100644 (file)
index 0000000..4257834
--- /dev/null
@@ -0,0 +1,89 @@
+#include "common.cuh"
+#include "wkv6.cuh"
+
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = CUDA_WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    __syncthreads();
+    _tf[tid] = tf[head_i * head_size + tid];
+    __syncthreads();
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4& k = (float4&)(_k[j]);
+            const float4& r = (float4&)(_r[j]);
+            const float4& tf = (float4&)(_tf[j]);
+            const float4& td = (float4&)(_td[j]);
+            float4& s = (float4&)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            y += r.x * (tf.x * kv.x + s.x);
+            y += r.y * (tf.y * kv.y + s.y);
+            y += r.z * (tf.z * kv.z + s.z);
+            y += r.w * (tf.w * kv.w + s.w);
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+        }
+        dst[t] = y;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * tf_d = (const float *)dst->src[3]->data;
+    const float * td_d = (const float *)dst->src[4]->data;
+    const float * s_d  = (const float *)dst->src[5]->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+}
diff --git a/ggml/src/ggml-cuda/wkv6.cuh b/ggml/src/ggml-cuda/wkv6.cuh
new file mode 100644 (file)
index 0000000..a7124ee
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index a62c67f4f1ceca2d147096a340ee1faa0f75e1c7..255bc64c6baddf13adb1d8e2a9820b88036c0e6a 100644 (file)
@@ -1194,272 +1194,8 @@ typedef void (*ggml_sycl_op_mul_mat_t)(
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
     const int64_t src1_ncols, const int64_t src1_padded_row_size,
     const queue_ptr &stream);
-typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream);
-
-static __dpct_inline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
-
-static __dpct_inline__ float op_add(const float a, const float b) {
-    return a + b;
-}
-
-static __dpct_inline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
-
-static __dpct_inline__ float op_div(const float a, const float b) {
-    return a / b;
-}
-
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-    }
-}
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
-
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-}
-
-static void acc_f32(const float * x, const float * y, float * dst, const int ne,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= ne) {
-        return;
-    }
-    int src1_idx = i - offset;
-    int oz = src1_idx / nb2;
-    int oy = (src1_idx - (oz * nb2)) / nb1;
-    int ox = src1_idx % nb1;
-    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
-        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
-    } else {
-        dst[i] = x[i];
-    }
-}
-
-static void gelu_f32(const float * x, float * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const float GELU_COEF_A    = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    float xi = x[i];
-    dst[i] = 0.5f * xi *
-             (1.0f +
-              sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
-}
-
-static void silu_f32(const float * x, float * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
-}
-
-static void gelu_quick_f32(const float *x, float *dst, int k,
-                           const sycl::nd_item<3> &item_ct1) {
-    const float GELU_QUICK_COEF = -1.702f;
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
-}
-
-static void tanh_f32(const float *x, float *dst, int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::tanh((float)(x[i]));
-}
-
-static void relu_f32(const float * x, float * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((float)(x[i]), (float)0);
-}
-
-static void hardsigmoid_f32(const float * x, float * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
-}
-
-static void hardswish_f32(const float * x, float * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
-}
-
-static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((float)(x[i]), (float)0) +
-             sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
-}
-
-static void sqr_f32(const float * x, float * dst, const int k,
-                    const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * x[i];
-}
-
-static void upscale_f32(const float  *x, float *dst, const int nb00, const int nb01,
-                        const int nb02, const int nb03, const int ne10, const int ne11,
-                        const int ne12, const int ne13, const float sf0, const float sf1,
-                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
-    int index = item_ct1.get_local_id(0) +
-               item_ct1.get_group(0) * item_ct1.get_local_range(0);
-    if (index >= ne10 * ne11 * ne12 * ne13) {
-        return;
-    }
-    // operation
-    int i10 = index % ne10;
-    int i11 = (index / ne10) % ne11;
-    int i12 = (index / (ne10 * ne11)) % ne12;
-    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
-
-    int i00 = i10 / sf0;
-    int i01 = i11 / sf1;
-    int i02 = i12 / sf2;
-    int i03 = i13 / sf3;
-
-    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
-}
-
-static void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
-                    const sycl::nd_item<3> &item_ct1) {
-    int nidx = item_ct1.get_local_id(2) +
-               item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (nidx >= ne0) {
-        return;
-    }
-
-    // operation
-    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
-                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
-    if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
-        item_ct1.get_group(0) < ne02) {
-        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
-                         item_ct1.get_group(0) * ne00 * ne01;
-            dst[offset_dst] = x[offset_src];
-    } else {
-        dst[offset_dst] = 0.0f;
-    }
-}
 
 template<int QUANT_BLOCK_TILE>
 static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
@@ -2148,297 +1884,6 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
     (void) dst;
 }
 
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
-    template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
-                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
-                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
-                    queue_ptr stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne0[] = {ne0, ne1, ne2, ne3};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb0[] = {nb0, nb1, nb2, nb3};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        for (int i = 0; i < 4; i++) {
-            if (nr[i] != 1) {
-                break;
-            }
-            if (i > 0) {
-                collapse_nb(cnb0, cne0);
-                collapse_nb(cnb1, cne1);
-                collapse(cne0);
-                collapse(cne1);
-            }
-        }
-        {
-            int64_t ne0 = cne0[0];
-            int64_t ne1 = cne0[1];
-            int64_t ne2 = cne0[2];
-            int64_t ne3 = cne0[3];
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb0[0];
-            size_t nb1 = cnb0[1];
-            size_t nb2 = cnb0[2];
-            size_t nb3 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
-                                s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
-        }
-    }
-};
-
-static void acc_f32_sycl(const float *x, const float *y, float *dst,
-                         const int n_elements, const int ne10, const int ne11,
-                         const int ne12, const int nb1, const int nb2,
-                         const int offset, queue_ptr stream) {
-    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
-                    item_ct1);
-        });
-}
-
-static void gelu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            gelu_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void silu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            silu_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            gelu_quick_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void tanh_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            tanh_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void relu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            relu_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
-                                 queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            hardsigmoid_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void hardswish_f32_sycl(const float *x, float *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            hardswish_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
-                                const float negative_slope,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
-        });
-}
-
-static void sqr_f32_sycl(const float *x, float *dst, const int k,
-                         queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            sqr_f32(x, dst, k, item_ct1);
-        });
-}
-
-static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
-                             const int nb02, const int nb03, const int ne10, const int ne11,
-                             const int ne12, const int ne13, const float sf0, const float sf1,
-                             const float sf2, const float sf3, queue_ptr stream) {
-    int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
-    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
-    stream->parallel_for(
-        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
-        [=](sycl::nd_item<1> item_ct1) {
-            upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
-        });
-}
-
-static void pad_f32_sycl(const float *x, float *dst, const int ne00,
-                         const int ne01, const int ne02, const int ne0,
-                         const int ne1, const int ne2, queue_ptr stream) {
-    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
-    sycl::range<3> gridDim(ne2, ne1, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
-        });
-}
 
 static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
                                    const int ky, const int kx_padded,
@@ -2816,6 +2261,58 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     }
 }
 
+static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
+                               const int nrows, queue_ptr stream) {
+    const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE);
+    const sycl::range<3> block_nums(1, nrows, 1);
+    const size_t shared_mem = 256 * sizeof(float);
+
+    stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> shared_data(
+            sycl::range<1>(shared_mem/sizeof(float)), cgh);
+        sycl::local_accessor<int, 1> shared_indices(
+            sycl::range<1>(shared_mem/sizeof(float)), cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                const int tid = item_ct1.get_local_id(2);
+                const int row = item_ct1.get_global_id(1);
+
+                float max_val = -INFINITY;
+                int max_idx = -1;
+
+                for (int col = tid; col < ncols; col += 256) {
+                    float val = x[row * ncols + col];
+                    if (val > max_val) {
+                        max_val = val;
+                        max_idx = col;
+                    }
+                }
+
+                shared_data[tid] = max_val;
+                shared_indices[tid] = max_idx;
+                item_ct1.barrier(sycl::access::fence_space::local_space);
+
+                for (int stride = 256/2; stride > 0; stride >>= 1) {
+                    if (tid < stride) {
+                        float val1 = shared_data[tid];
+                        float val2 = shared_data[tid + stride];
+                        if (val2 > val1) {
+                            shared_data[tid] = val2;
+                            shared_indices[tid] = shared_indices[tid + stride];
+                        }
+                    }
+                    item_ct1.barrier(sycl::access::fence_space::local_space);
+                }
+
+
+                if (tid == 0) {
+                    dst[row] = shared_indices[0];
+                }
+            });
+    });
+}
 static void diag_mask_inf_f32_sycl(const float *x, float *dst,
                                    const int ncols_x, const int nrows_x,
                                    const int rows_per_channel, const int n_past,
@@ -2855,362 +2352,111 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
     } else {
         // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
         GGML_ABORT("fatal error");
-    }
-    char * dst_ptr = (char *) dst;
-
-    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
-    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
-    const enum ggml_type type = src->type;
-    const int64_t ts = ggml_type_size(type);
-    const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
-
-    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
-        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
-        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
-                                    kind, *stream));
-
-    } else if (nb0 == ts) {
-        return CHECK_TRY_ERROR(
-            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
-                                    ts * ne0 / bs, i1_diff, kind, *stream));
-    } else {
-        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
-                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
-            /*
-            DPCT1001:85: The statement could not be removed.
-            */
-            /*
-            DPCT1000:86: Error handling if-stmt was detected but could not be
-            rewritten.
-            */
-            if (r != 0) return r;
-        }
-        return 0;
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_d, const float *src1_d,
-                                  float *dst_d, const queue_ptr &stream) {
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
-                                src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_F32:
-            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        default:
-            // TODO: k-quants
-            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-template <class op>
-inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd,
-                                   const queue_ptr &main_stream) {
-
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
-             (sycl::half *)dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
-        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
-             main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
-}
-
-static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_d, const float *src1_d,
-                                float *dst_d,
-                                const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
-
-    (void) src1;
-    (void) src1_d;
-}
-
-inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
-
-inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
-
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
-
-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
-
-    (void) dst;
-}
-
-inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
-
-inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
-
-inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                     const ggml_tensor *src1, ggml_tensor *dst,
-                                     const float *src0_dd, const float *src1_dd,
-                                     float *dst_dd,
-                                     const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd, const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    }
+    char * dst_ptr = (char *) dst;
 
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
+    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
+    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;
 
-    leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
+        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
+        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
+                                    kind, *stream));
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    } else if (nb0 == ts) {
+        return CHECK_TRY_ERROR(
+            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
+                                    ts * ne0 / bs, i1_diff, kind, *stream));
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
+                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
+            /*
+            DPCT1001:85: The statement could not be removed.
+            */
+            /*
+            DPCT1000:86: Error handling if-stmt was detected but could not be
+            rewritten.
+            */
+            if (r != 0) return r;
+        }
+        return 0;
+    }
 }
-
-inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
-
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const float *src0_dd, const float *src1_dd,
-                                 float *dst_dd,
-                                 const queue_ptr &main_stream) {
+static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_d, const float *src1_d,
+                                  float *dst_d, const queue_ptr &stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-    const float sf0 = (float)dst->ne[0]/src0->ne[0];
-    const float sf1 = (float)dst->ne[1]/src0->ne[1];
-    const float sf2 = (float)dst->ne[2]/src0->ne[2];
-    const float sf3 = (float)dst->ne[3]/src0->ne[3];
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
-    upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
-                     dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                     main_stream);
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
+                                src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("fatal error");
+            break;
+    }
 }
 
-inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const float *src0_d, const float *src1_d,
+                                float *dst_d,
+                                const queue_ptr &main_stream) {
 
-    pad_f32_sycl(src0_dd, dst_dd,
-        src0->ne[0], src0->ne[1], src0->ne[2],
-        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
 
     (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    (void) src1_d;
 }
 
+
 inline void ggml_sycl_op_mul_mat_sycl(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -3379,6 +2625,23 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
     (void) src1_dd;
 }
 
+inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_dd, const float *src1_dd,
+                                  float *dst_dd,
+                                  const queue_ptr &main_stream) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ne = ggml_nelements(src0);
+
+    sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                   const ggml_tensor *src1, ggml_tensor *dst,
                                   const float *src0_dd, const float *src1_dd,
@@ -3419,6 +2682,25 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_ten
     (void) src1_dd;
 }
 
+inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const float *src0_dd, const float *src1_dd,
+                                 float *dst_dd,
+                                 const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                        const ggml_tensor *src1,
                                        ggml_tensor *dst, const float *src0_dd,
@@ -3489,46 +2771,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tenso
     (void) src1_dd;
 }
 
-static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const ggml_sycl_op_flatten_t op) try {
-    const int64_t nrows0 = ggml_nrows(src0);
-
-    const bool use_src1 = src1 != nullptr;
-    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
-
-    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(              dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-
-    ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
-    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
-
-    // dd = data device
-    float * src0_ddf = (float *) src0->data;
-    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
-    float *  dst_ddf = (float *) dst->data;
-
-    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
-    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
-    ggml_sycl_pool_alloc<float>  dst_f(ctx.pool());
-
-    ggml_sycl_set_device(ctx.device);
-    queue_ptr main_stream = ctx.stream();
-    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
-        // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
-
-    // do the computation
-    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
-    // print_ggml_tensor("tensor", dst);
-}
-catch (sycl::exception const &exc) {
-
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
-
 static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
     static bool peer_access_enabled = false;
 
@@ -3908,112 +3150,21 @@ static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tenso
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
 static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
@@ -4632,6 +3783,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
 }
 
+static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
+}
+
 static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0));
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
@@ -4642,6 +3798,11 @@ static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor
     ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
 }
 
+static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
+}
+
 static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -4673,6 +3834,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
     ggml_sycl_func_t func;
 
     switch (tensor->op) {
+        case GGML_OP_ARGMAX:
+            func = ggml_sycl_argmax;
+            break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             func = ggml_sycl_op_conv_transpose_1d;
             break;
@@ -4686,19 +3850,32 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             func = ggml_sycl_dup;
             break;
         case GGML_OP_ADD:
+        case GGML_OP_ADD1: // TODO: more efficient implementation
             func = ggml_sycl_add;
             break;
+        case GGML_OP_SUB:
+            func = ggml_sycl_sub;
+            break;
         case GGML_OP_ACC:
             func = ggml_sycl_acc;
             break;
         case GGML_OP_MUL:
             func = ggml_sycl_mul;
             break;
+        case GGML_OP_LOG:
+            func = ggml_sycl_log;
+            break;
         case GGML_OP_DIV:
             func = ggml_sycl_div;
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_NEG:
+                    func = ggml_sycl_neg;
+                    break;
+                case GGML_UNARY_OP_STEP:
+                    func = ggml_sycl_step;
+                    break;
                 case GGML_UNARY_OP_GELU:
                     func = ggml_sycl_gelu;
                     break;
@@ -4714,12 +3891,18 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
                 case GGML_UNARY_OP_RELU:
                     func = ggml_sycl_relu;
                     break;
+                case GGML_UNARY_OP_SIGMOID:
+                    func = ggml_sycl_sigmoid;
+                    break;
                 case GGML_UNARY_OP_HARDSIGMOID:
                     func = ggml_sycl_hardsigmoid;
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
                     func = ggml_sycl_hardswish;
                     break;
+                case GGML_UNARY_OP_EXP:
+                    func = ggml_sycl_exp;
+                    break;
                 default:
                     return false;
             }
@@ -4757,12 +3940,24 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
             }
             func = ggml_sycl_mul_mat_id;
             break;
+        case GGML_OP_OUT_PROD:
+            func = ggml_sycl_op_out_prod;
+            break;
         case GGML_OP_SCALE:
             func = ggml_sycl_scale;
             break;
         case GGML_OP_SQR:
             func = ggml_sycl_sqr;
             break;
+        case GGML_OP_SQRT:
+            func = ggml_sycl_sqrt;
+            break;
+        case GGML_OP_SIN:
+            func = ggml_sycl_sin;
+            break;
+        case GGML_OP_COS:
+            func = ggml_sycl_cos;
+            break;
         case GGML_OP_CLAMP:
             func = ggml_sycl_clamp;
             break;
@@ -4794,6 +3989,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_POOL_2D:
             func = ggml_sycl_pool2d;
             break;
+        case GGML_OP_SUM:
+            func = ggml_sycl_sum;
+            break;
         case GGML_OP_SUM_ROWS:
             func = ggml_sycl_sum_rows;
             break;
@@ -4803,6 +4001,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_TIMESTEP_EMBEDDING:
             func = ggml_sycl_op_timestep_embedding;
             break;
+        case GGML_OP_RWKV_WKV6:
+            func = ggml_sycl_op_rwkv_wkv6;
+            break;
         default:
             return false;
     }
@@ -5125,13 +4326,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_NEG:
+                case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_TANH:
+                case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
                 default:
                     return false;
@@ -5168,6 +4373,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 }
                 return true;
             } break;
+        case GGML_OP_OUT_PROD:
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -5213,10 +4420,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
-                int dim = op->op_params[0];
-                return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_DUP:
+        case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_REPEAT:
@@ -5225,11 +4432,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_ADD:
+        case GGML_OP_ADD1:
+        case GGML_OP_LOG:
+        case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
             return true;
         case GGML_OP_CONT:
@@ -5243,6 +4456,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             // TODO: add support for the new F32 operations
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
+        case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
@@ -5251,6 +4465,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_RWKV_WKV6:
             return true;
         default:
             return false;
@@ -5268,9 +4483,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
     return buft_ctx->device == sycl_ctx->device;
 }
 
+static int64_t get_op_batch_size(const ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_GET_ROWS:
+            return op->ne[1]; // this will increse the speed of prefill in test
+        case GGML_OP_MUL_MAT:
+            return op->ne[1];
+        case GGML_OP_MUL_MAT_ID:
+        case GGML_OP_ROPE:
+            return op->ne[2];
+        default:
+            return ggml_nrows(op);
+    }
+}
+
 static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     const int min_batch_size = 32;
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
+    return get_op_batch_size(op) >= min_batch_size;
     GGML_UNUSED(dev);
 }
 
index d21b5f8dd2627537b9fb9013d1e87695f2646673..85748a5b4c1a398a207536b4af3d6f5da6fc61c3 100644 (file)
@@ -26,5 +26,8 @@
 #include "softmax.hpp"
 #include "tsembd.hpp"
 #include "im2col.hpp"
+#include "wkv6.hpp"
+#include "outprod.hpp"
+#include "element_wise.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
index cf5291b31fe9172029e533f5381a0d9f601e1ee2..97ab2003c7fa1201473dff617b7512b0d2cd0bef 100644 (file)
@@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
   }
   return sycl_down_blk_size;
 }
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const ggml_sycl_op_flatten_t op) try {
+    const int64_t nrows0 = ggml_nrows(src0);
+
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(              dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+
+    ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
+    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
+
+    // dd = data device
+    float * src0_ddf = (float *) src0->data;
+    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+    float *  dst_ddf = (float *) dst->data;
+
+    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+    ggml_sycl_pool_alloc<float>  dst_f(ctx.pool());
+
+    ggml_sycl_set_device(ctx.device);
+    queue_ptr main_stream = ctx.stream();
+    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+        // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+
+    // do the computation
+    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    // print_ggml_tensor("tensor", dst);
+}
+catch (sycl::exception const &exc) {
+
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
index bc0faa867dcfe02f8f9f1dacee7ba4c965d5c109..4549fa5e95a098b7f412b5e875e93a0fbd2e475f 100644 (file)
@@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
 
 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
 
+typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                       const ggml_tensor *src1,
+                                       ggml_tensor *dst, const float *src0_dd,
+                                       const float *src1_dd, float *dst_dd,
+                                       const queue_ptr &main_stream);
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1));
+    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) /
+                   ne3;
+    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) %
+                   ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0;
+         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(ggml_backend_sycl_context & ctx,
+                    const struct ggml_tensor *src0,
+                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                    queue_ptr stream) {
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
+
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
+
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
+
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                            ne2, ne3, ne10, ne11, ne12, ne13,
+                                            s1, s2, s3, s11, s12, s13,
+                                            item_ct1);
+                    });
+            }
+        }
+    }
+};
+
+template <class op>
+inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd,
+                                   const queue_ptr &main_stream) {
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+             (sycl::half *)dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+             main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+
+void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const ggml_sycl_op_flatten_t op);
+
 #endif // GGML_SYCL_COMMON_HPP
index 632eedb9d42b8185199938dfcb53d4432cd70f35..c90c452d8783fdfb53a1b1a015eed578792c3beb 100644 (file)
@@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
           concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
         });
     break;
+  // dim >=2 will be dispatched to the default path
   default:
     stream->parallel_for(
         sycl::nd_range<3>(gridDim *
diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp
new file mode 100644 (file)
index 0000000..e5cd736
--- /dev/null
@@ -0,0 +1,1011 @@
+#include "common.hpp"
+#include "element_wise.hpp"
+
+void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= ne) {
+        return;
+    }
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
+}
+
+void gelu_f32(const float * x, float * dst, const int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const float GELU_COEF_A    = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f * xi *
+             (1.0f +
+              sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
+}
+
+void silu_f32(const float * x, float * dst, const int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
+}
+
+void gelu_quick_f32(const float *x, float *dst, int k,
+                           const sycl::nd_item<3> &item_ct1) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
+}
+
+void tanh_f32(const float *x, float *dst, int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::tanh((float)(x[i]));
+}
+
+void relu_f32(const float * x, float * dst, const int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::fmax((float)(x[i]), (float)0);
+}
+
+void sigmoid_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i]));
+}
+
+void sqrt_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::sqrt(x[i]);
+}
+
+void sin_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::sin(x[i]);
+}
+
+void cos_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::cos(x[i]);
+}
+
+void hardsigmoid_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+void hardswish_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+}
+
+void exp_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::exp(x[i]);
+}
+
+void log_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    float xi = x[i];
+    if (xi <= 0) {
+        dst[i] = -INFINITY;
+    } else {
+        dst[i] = sycl::log(xi);
+    }
+}
+
+void neg_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = -x[i];
+}
+
+void step_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] > 0.0f;
+}
+
+void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
+                           const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::fmax((float)(x[i]), (float)0) +
+             sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
+}
+
+void sqr_f32(const float * x, float * dst, const int k,
+                    const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
+}
+
+void upscale_f32(const float  *x, float *dst, const int nb00, const int nb01,
+                        const int nb02, const int nb03, const int ne10, const int ne11,
+                        const int ne12, const int ne13, const float sf0, const float sf1,
+                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
+    int index = item_ct1.get_local_id(0) +
+               item_ct1.get_group(0) * item_ct1.get_local_range(0);
+    if (index >= ne10 * ne11 * ne12 * ne13) {
+        return;
+    }
+    // operation
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+    int i00 = i10 / sf0;
+    int i01 = i11 / sf1;
+    int i02 = i12 / sf2;
+    int i03 = i13 / sf3;
+
+    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+}
+
+void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
+                    const sycl::nd_item<3> &item_ct1) {
+    int nidx = item_ct1.get_local_id(2) +
+               item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+    if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
+        item_ct1.get_group(0) < ne02) {
+        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
+                         item_ct1.get_group(0) * ne00 * ne01;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
+    }
+}
+
+
+
+void acc_f32_sycl(const float *x, const float *y, float *dst,
+                         const int n_elements, const int ne10, const int ne11,
+                         const int ne12, const int nb1, const int nb2,
+                         const int offset, queue_ptr stream) {
+    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+                    item_ct1);
+        });
+}
+
+void gelu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_f32(x, dst, k, item_ct1);
+        });
+}
+
+void silu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            silu_f32(x, dst, k, item_ct1);
+        });
+}
+
+void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
+                                queue_ptr stream) {
+    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_quick_f32(x, dst, k, item_ct1);
+        });
+}
+
+void tanh_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            tanh_f32(x, dst, k, item_ct1);
+        });
+}
+
+void relu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            relu_f32(x, dst, k, item_ct1);
+        });
+}
+
+void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+                                 queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardsigmoid_f32(x, dst, k, item_ct1);
+        });
+}
+
+void hardswish_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardswish_f32(x, dst, k, item_ct1);
+        });
+}
+
+void exp_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            exp_f32(x, dst, k, item_ct1);
+        });
+}
+
+void log_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            log_f32(x, dst, k, item_ct1);
+        });
+}
+
+void neg_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            neg_f32(x, dst, k, item_ct1);
+        });
+}
+
+void step_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            step_f32(x, dst, k, item_ct1);
+        });
+}
+
+void sigmoid_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sigmoid_f32(x, dst, k, item_ct1);
+        });
+}
+
+void sqrt_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sqrt_f32(x, dst, k, item_ct1);
+        });
+}
+
+void sin_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sin_f32(x, dst, k, item_ct1);
+        });
+}
+
+void cos_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            cos_f32(x, dst, k, item_ct1);
+        });
+}
+
+void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
+                                const float negative_slope,
+                                queue_ptr stream) {
+    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
+        });
+}
+
+void sqr_f32_sycl(const float *x, float *dst, const int k,
+                         queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sqr_f32(x, dst, k, item_ct1);
+        });
+}
+
+void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+                             const int nb02, const int nb03, const int ne10, const int ne11,
+                             const int ne12, const int ne13, const float sf0, const float sf1,
+                             const float sf2, const float sf3, queue_ptr stream) {
+    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
+        });
+}
+
+void pad_f32_sycl(const float *x, float *dst, const int ne00,
+                         const int ne01, const int ne02, const int ne0,
+                         const int ne1, const int ne2, queue_ptr stream) {
+    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+    sycl::range<3> gridDim(ne2, ne1, num_blocks);
+    stream->parallel_for(
+        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
+        });
+}
+
+inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst,
+                                    const float *src0_dd, const float *src1_dd,
+                                    float *dst_dd,
+                                    const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor *dst,
+                                     const float *src0_dd, const float *src1_dd,
+                                     float *dst_dd,
+                                     const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst,
+                                    const float *src0_dd, const float *src1_dd,
+                                    float *dst_dd,
+                                    const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+    leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const float *src0_dd, const float *src1_dd,
+                                 float *dst_dd,
+                                 const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const float sf0 = (float)dst->ne[0]/src0->ne[0];
+    const float sf1 = (float)dst->ne[1]/src0->ne[1];
+    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const float sf3 = (float)dst->ne[3]/src0->ne[3];
+
+    upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                     dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
+                     main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+    pad_f32_sycl(src0_dd, dst_dd,
+        src0->ne[0], src0->ne[1], src0->ne[2],
+        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
+inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in bytes
+
+    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+
+    (void) dst;
+}
+
+inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
+
+
+void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+
+void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+
+
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp
new file mode 100644 (file)
index 0000000..8152edf
--- /dev/null
@@ -0,0 +1,76 @@
+#ifndef GGML_SYCL_ELEMENTWISE_HPP
+#define GGML_SYCL_ELEMENTWISE_HPP
+
+#include "common.hpp"
+
+static __dpct_inline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+static __dpct_inline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+static __dpct_inline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+static __dpct_inline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+static __dpct_inline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
+
+void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+#endif // GGML_SYCL_ELEMENTWISE_HPP
diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp
new file mode 100644 (file)
index 0000000..c2779df
--- /dev/null
@@ -0,0 +1,55 @@
+#include <sycl/sycl.hpp>
+#include "outprod.hpp"
+
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+    const ggml_tensor* src1, ggml_tensor* dst) {
+
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    // Get SYCL queue
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Dimension checks
+    GGML_ASSERT(ne01 == ne11);  // Inner dimensions must match
+    GGML_ASSERT(ne0 == ne00);   // Output rows match src0 rows
+    GGML_ASSERT(ne1 == ne10);   // Output cols match src1 cols
+
+    // Get data pointers
+    const float* src0_d = (const float*)src0->data;
+    const float* src1_d = (const float*)src1->data;
+    float* dst_d = (float*)dst->data;
+
+    // GEMM parameters
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    // Handle transposition of src1
+    const bool src1_T = ggml_is_transposed(src1);
+    const oneapi::mkl::transpose src1_op =
+        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+
+    try {
+        // Perform matrix multiplication using oneMKL GEMM
+        oneapi::mkl::blas::gemm(*stream,
+            oneapi::mkl::transpose::nontrans, src1_op,
+            ne0, ne1, ne01,
+            alpha,
+            src0_d, ne00,
+            src1_d, ldb,
+            beta,
+            dst_d, ne0);
+    }
+    catch (sycl::exception const& exc) {
+        std::cerr << exc.what() << std::endl;
+        GGML_ASSERT(false);
+    }
+}
diff --git a/ggml/src/ggml-sycl/outprod.hpp b/ggml/src/ggml-sycl/outprod.hpp
new file mode 100644 (file)
index 0000000..9c04273
--- /dev/null
@@ -0,0 +1,11 @@
+#ifndef GGML_SYCL_OUTPROD_HPP
+#define GGML_SYCL_OUTPROD_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+    const ggml_tensor* src1, ggml_tensor* dst);
+
+
+#endif // GGML_SYCL_OUTPROD_HPP
+
index 340ab8e932bcff87090da02ebeedaae7e81f9c75..af1890727df8f99576a77cb41b5a9a8dbfd20ce6 100644 (file)
 #define SYCL_RELU_BLOCK_SIZE 256
 #define SYCL_HARDSIGMOID_BLOCK_SIZE 256
 #define SYCL_HARDSWISH_BLOCK_SIZE 256
+#define SYCL_EXP_BLOCK_SIZE 256
+#define SYCL_NEG_BLOCK_SIZE 256
+#define SYCL_SIGMOID_BLOCK_SIZE 256
+#define SYCL_SQRT_BLOCK_SIZE 256
+#define SYCL_SIN_BLOCK_SIZE 256
 #define SYCL_SQR_BLOCK_SIZE 256
 #define SYCL_CPY_BLOCK_SIZE 32
 #define SYCL_SCALE_BLOCK_SIZE 256
@@ -41,6 +46,7 @@
 #define SYCL_ACC_BLOCK_SIZE 256
 #define SYCL_IM2COL_BLOCK_SIZE 256
 #define SYCL_POOL2D_BLOCK_SIZE 256
+#define SYCL_ARGMAX_BLOCK_SIZE 256
 #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
 #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
 
diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp
new file mode 100644 (file)
index 0000000..4c737f4
--- /dev/null
@@ -0,0 +1,138 @@
+#include <sycl/sycl.hpp>
+#include "wkv6.hpp"
+
+constexpr int WKV_BLOCK_SIZE = 64;  // Matching CUDA_WKV_BLOCK_SIZE
+
+// Helper function for the main kernel
+static void rwkv_wkv_f32_kernel(
+    const int B, const int T, const int C, const int H,
+    const float* k, const float* v, const float* r,
+    const float* tf, const float* td, const float* s,
+    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
+
+    const int tid = item_ct1.get_local_id(2);
+    const int bid = item_ct1.get_group(2);
+
+    const int head_size = WKV_BLOCK_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    // Set up shared memory pointers
+    float* _k = shared_mem;
+    float* _r = _k + head_size;
+    float* _tf = _r + head_size;
+    float* _td = _tf + head_size;
+
+    // Local state array
+    float state[WKV_BLOCK_SIZE];
+
+    // Load initial state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    // Sync threads before shared memory operations
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Load time-mixing parameters
+    _tf[tid] = tf[head_i * head_size + tid];
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    // Main sequence processing loop
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
+         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
+         t += C) {
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        // Load current timestep data to shared memory
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        const float _v = v[t];
+        float y = 0;
+
+        // Process in chunks of 4 for better vectorization
+        sycl::float4 k4, r4, tf4, td4, s4, kv4;
+        #pragma unroll
+        for (int j = 0; j < head_size; j += 4) {
+            // Load data in vec4 chunks
+            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
+            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
+            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            // Compute key-value product
+            sycl::float4 kv4 = k4 * _v;
+
+            // Accumulate weighted sum
+            y += sycl::dot(r4, tf4 * kv4 + s4);
+
+            // Update state
+            s4 = s4 * td4 + kv4;
+
+            // Store updated state
+            state[j] = s4.x();
+            state[j+1] = s4.y();
+            state[j+2] = s4.z();
+            state[j+3] = s4.w();
+        }
+
+        dst[t] = y;
+    }
+
+    // Save final state
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
+    const ggml_tensor* src1, ggml_tensor* dst) {
+
+    const float* k_d = (const float*)dst->src[0]->data;
+    const float* v_d = (const float*)dst->src[1]->data;
+    const float* r_d = (const float*)dst->src[2]->data;
+    const float* tf_d = (const float*)dst->src[3]->data;
+    const float* td_d = (const float*)dst->src[4]->data;
+    const float* s_d = (const float*)dst->src[5]->data;
+    float* dst_d = (float*)dst->data;
+
+    const int64_t B = dst->src[5]->ne[1];
+    const int64_t T = dst->src[0]->ne[3];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[2];
+
+    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
+
+    dpct::queue_ptr stream = ctx.stream();
+
+    // Calculate execution configuration
+    const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
+    sycl::range<3> block_dims(1, 1, C / H);
+    sycl::range<3> grid_dims(1, 1, B * H);
+
+    // Submit kernel
+    stream->submit([&](sycl::handler& cgh) {
+        sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                rwkv_wkv_f32_kernel(
+                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
+                    item_ct1, shared_mem_acc.get_pointer()
+                );
+            });
+    });
+}
diff --git a/ggml/src/ggml-sycl/wkv6.hpp b/ggml/src/ggml-sycl/wkv6.hpp
new file mode 100644 (file)
index 0000000..ddfa337
--- /dev/null
@@ -0,0 +1,10 @@
+#ifndef GGML_SYCL_WKV6_HPP
+#define GGML_SYCL_WKV6_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+    const ggml_tensor *src1, ggml_tensor * dst);
+
+
+#endif // GGML_SYCL_WKV6_HPP
index 266a0d6f044db8bb0dc6d4a238dbc0be866dcdf5..bc034015f470ccbf2bb31cb8205494caba43d342 100644 (file)
@@ -975,7 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "WIN_UNPART",
     "GET_REL_POS",
     "ADD_REL_POS",
-    "RWKV_WKV",
+    "RWKV_WKV6",
 
     "UNARY",
 
@@ -1070,7 +1070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "win_unpart(x)",
     "get_rel_pos(x)",
     "add_rel_pos(x)",
-    "rwkv_wkv(k, v, r, tf, td, s)",
+    "rwkv_wkv6(k, v, r, tf, td, s)",
 
     "unary(x)",
 
@@ -4503,9 +4503,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
-// ggml_rwkv_wkv
+// ggml_rwkv_wkv6
 
-struct ggml_tensor * ggml_rwkv_wkv(
+struct ggml_tensor * ggml_rwkv_wkv6(
         struct ggml_context * ctx,
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
@@ -4537,7 +4537,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
     const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    result->op     = GGML_OP_RWKV_WKV;
+    result->op     = GGML_OP_RWKV_WKV6;
     result->src[0] = k;
     result->src[1] = v;
     result->src[2] = r;
@@ -6084,7 +6084,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_GET_REL_POS:
         case GGML_OP_ADD_REL_POS:
-        case GGML_OP_RWKV_WKV:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32: