]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : implement REGLU/GEGLU/SWIGLU ops (llama/14158)
authorSigbjørn Skjæret <redacted>
Sun, 29 Jun 2025 09:04:10 +0000 (11:04 +0200)
committerGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 14:54:53 +0000 (17:54 +0300)
* implement unary REGLU/GEGLU/SWIGLU cpu ops

* relax constraints

* duplicate shape of source

* fix ggml_vec_geglu_f16

* special case gated ops

* implement unary REGLU/GEGLU/SWIGLU cuda ops

* tighten constraints again

* refactor into GGML_GLU_OP

* metal : add glu kernels

ggml-ci

* add CUDA_GLU_BLOCK_SIZE [no ci]

* more constraints and use 64bit ints

ggml-ci

* 64bit multiplication [no ci]

* implement swapped variants (cpu/cuda)

* update comment [no ci]

ggml-ci

* Vulkan: Add GLU ops and shaders

* SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate

* ggml : implement GLU for split up/gate (llama/14181)

* implement GLU for split up/gate

* add tests for ggml_glu_split

* Vulkan: Implement glu_split logic and shader support

* add split to logging [no ci]

* SYCL: refactor element_size ops and add split up and gate support to gated kernels

* SYCL: switch GEGLU to use tanh approximation

---------

Co-authored-by: 0cc4m <redacted>
Co-authored-by: Akarshan <redacted>
* GGML: increase OP count in assertion

* Refactor: Optimize SYCL element-wise operations with unary function inlining

This commit refactors the SYCL element-wise operations to improve performance by:

- Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead.
- Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic.
- Replacing direct kernel calls with calls to these inlined functions.
- Using `__dpct_inline__` to encourage compiler inlining.
- Minor code cleanup and consistency improvements.

The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices.

* vulkan: Increase workgroup size for GLU, for performance (llama/14345)

* vulkan: Increase workgroup size for GLU, for performance

* vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup

* merge fix

* metal : add support for split and swap

ggml-ci

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: 0cc4m <redacted>
Co-authored-by: Akarshan <redacted>
Co-authored-by: Jeff Bolz <redacted>
23 files changed:
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/ops.h
ggml/src/ggml-cpu/vec.cpp
ggml/src/ggml-cpu/vec.h
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-sycl/element_wise.cpp
ggml/src/ggml-sycl/element_wise.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
ggml/src/ggml.c

index 042a13e055bd3e80c86d764d97f96255a86fc589..3a04ad58a4930d63714c849e1c3f1764fbcda8a9 100644 (file)
@@ -520,6 +520,8 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
 
+        GGML_OP_GLU,
+
         GGML_OP_COUNT,
     };
 
@@ -543,6 +545,14 @@ extern "C" {
         GGML_UNARY_OP_COUNT,
     };
 
+    enum ggml_glu_op {
+        GGML_GLU_OP_REGLU,
+        GGML_GLU_OP_GEGLU,
+        GGML_GLU_OP_SWIGLU,
+
+        GGML_GLU_OP_COUNT,
+    };
+
     enum ggml_object_type {
         GGML_OBJECT_TYPE_TENSOR,
         GGML_OBJECT_TYPE_GRAPH,
@@ -658,6 +668,7 @@ extern "C" {
     GGML_API const char * ggml_op_symbol(enum ggml_op   op);
 
     GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
     GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
 
     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
@@ -762,6 +773,7 @@ extern "C" {
     GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
 
     GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+    GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);
 
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1090,6 +1102,63 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // gated linear unit ops
+    // A: n columns, r rows,
+    // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
+    GGML_API struct ggml_tensor * ggml_glu(
+            struct ggml_context * ctx,
+             struct ggml_tensor * a,
+             enum ggml_glu_op     op,
+             bool                 swapped);
+
+    GGML_API struct ggml_tensor * ggml_reglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_reglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_swiglu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // A: n columns, r rows,
+    // B: n columns, r rows,
+    GGML_API struct ggml_tensor * ggml_glu_split(
+            struct ggml_context * ctx,
+             struct ggml_tensor * a,
+             struct ggml_tensor * b,
+             enum ggml_glu_op     op);
+
+    GGML_API struct ggml_tensor * ggml_reglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_swiglu_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
index 2042ee71f1f804c647e715f2ba770721f25e25d2..1d68cde71a65e94d2493968d2d4347fb45f891d3 100644 (file)
@@ -1949,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_unary(params, tensor);
             } break;
+        case GGML_OP_GLU:
+            {
+                ggml_compute_forward_glu(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2159,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     GGML_ABORT("fatal error");
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
         case GGML_OP_SILU_BACK:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
index 7ead13f42a3a526f653d1fb4f5d22cd601d82985..3b07e2497416962b724ea7d16b3eceb012f7db9d 100644 (file)
@@ -3184,6 +3184,435 @@ void ggml_compute_forward_silu_back(
     }
 }
 
+// ggml_compute_forward_reglu
+
+static void ggml_compute_forward_reglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_reglu_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_reglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_reglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_reglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu
+
+static void ggml_compute_forward_geglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_swiglu
+
+static void ggml_compute_forward_swiglu_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_swiglu_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -8057,6 +8486,34 @@ void ggml_compute_forward_unary(
     }
 }
 
+//ggml_compute_forward_glu
+
+void ggml_compute_forward_glu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_glu_op op = ggml_get_glu_op(dst);
+
+    switch (op) {
+        case GGML_GLU_OP_REGLU:
+            {
+                ggml_compute_forward_reglu(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU:
+            {
+                ggml_compute_forward_geglu(params, dst);
+            } break;
+        case GGML_GLU_OP_SWIGLU:
+            {
+                ggml_compute_forward_swiglu(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_get_rel_pos
 
 static void ggml_compute_forward_get_rel_pos_f16(
index 3a395fdcd9f048d76684b14da9de6ac5d6d16ee1..5b384e4ba5fce92259c8bf005358fa679d23fa7c 100644 (file)
@@ -94,6 +94,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
 void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index 5e34d79a1695f232f8da6c1c66dacd9e216394bf..ed5d7aefc35b324884bed2ef7f17d628415acc6c 100644 (file)
@@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
+    }
+#endif
+    for (; i < n; ++i) {
+        y[i] = ggml_silu_f32(x[i]) * g[i];
+    }
+}
+
 ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
index 84f6c0e6d26c4a493752e1aa65e1485f5fb18a16..ebd4b75613451160f80a70b8a4ab2cb237fba435 100644 (file)
@@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }
 
+inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
+    }
+}
+
+inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
+    }
+}
+
+#ifdef GGML_GELU_FP16
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        if (x[i] <= -10.0f) {
+            y[i] = 0.0f;
+        } else if (x[i] >= 10.0f) {
+            y[i] = x[i] * g[i];
+        } else {
+            ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+            memcpy(&t, &fp16, sizeof(uint16_t));
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
+        }
+    }
+}
+#else
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_f32(x[i]) * g[i];
+    }
+}
+#endif
+
+inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
+    }
+}
+
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
+
+inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_FP16_TO_FP32(x[i]);
+        float w = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
index 811422f38507325db6af87c16506830988ce5e33..086f9a56c4acaf392aabf1835f212a6d355430c8 100644 (file)
@@ -2303,6 +2303,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_cuda_op_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_cuda_op_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_cuda_op_swiglu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_cuda_op_norm(ctx, dst);
             break;
@@ -3096,6 +3111,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
index 2c0375fbe3cf6156eb68050079da01e2ec1d23a3..ba3c0f13762b0abb8a62f8bcdc2b58ef83278454 100644 (file)
@@ -196,6 +196,95 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_log>(ctx, dst);
 }
 
+/* gated ops */
+
+template <float (*op)(float), typename T>
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
+    if (src0->type == GGML_TYPE_F16) {
+        half * src0_p = (half *) src0_d;
+        half * src1_p = (half *) src1_d;
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
+    } else {
+        float * src0_p = (float *) src0_d;
+        float * src1_p = (float *) src1_d;
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
+    }
+}
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
index 6686fc17e9193d7938c323d45599a7772726ff15..9094f1d0bad3740a18618caee851cecc911fe364 100644 (file)
@@ -15,6 +15,7 @@
 #define CUDA_SQRT_BLOCK_SIZE 256
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
+#define CUDA_GLU_BLOCK_SIZE 256
 
 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -57,3 +58,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 260440aedde00c05dce10b92eb158d1a600c4abc..7a9aab31684e18522859c70e03cbfa7c4412a150 100644 (file)
@@ -422,6 +422,17 @@ typedef struct {
     int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct{
+    int32_t  ne00;
+    uint64_t nb01;
+    int32_t  ne10;
+    uint64_t nb11;
+    int32_t  ne0;
+    uint64_t nb1;
+    int32_t  i00;
+    int32_t  i10;
+} ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
index 349f0ff998ed13d7e2262e42cea2efdea19f5659..12a366957891c63e43a8ebd9d2191c33b6ecd828 100644 (file)
@@ -526,6 +526,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1502,6 +1505,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,                           reglu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,                           geglu,                           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,                          swiglu,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
@@ -1680,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+               default:
+                    return false;
+            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -2419,6 +2434,62 @@ static bool ggml_metal_encode_node(
                     GGML_ABORT("fatal error");
                 }
             } break;
+        case GGML_OP_GLU:
+            {
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                if (src1) {
+                    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+                }
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (ggml_get_glu_op(node)) {
+                    case GGML_GLU_OP_REGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
+                        break;
+                    case GGML_GLU_OP_GEGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
+                        break;
+                    case GGML_GLU_OP_SWIGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                const int32_t swp = ((const int32_t *) dst->op_params)[1];
+
+                const int32_t i00 = swp ? ne0 : 0;
+                const int32_t i10 = swp ? 0 : ne0;
+
+                ggml_metal_kargs_glu args = {
+                    /*.ne00 =*/ ne00,
+                    /*.nb01 =*/ nb01,
+                    /*.ne10 =*/ src1 ? ne10 : ne00,
+                    /*.nb11 =*/ src1 ? nb11 : nb01,
+                    /*.ne0  =*/ ne0,
+                    /*.nb1  =*/ nb1,
+                    /*.i00  =*/ src1 ? 0 : i00,
+                    /*.i10  =*/ src1 ? 0 : i10,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                if (src1) {
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                }
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:3];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_SQR:
             {
                 GGML_ASSERT(ggml_is_contiguous(src0));
index 984a0ab503e7d9a0aecab223ac5e5f193c8cdc84..fc3cfe35a34bf96c3f989d6143deaeba3aae8048 100644 (file)
@@ -1191,6 +1191,70 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        dst_row[i0] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i0] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i0] = silu*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
index c56924ce8322fcc8d04ffa927143cc481d81f401..c7788bdb6bf8c4618fade66872175621625d105d 100644 (file)
@@ -1,12 +1,19 @@
 #include "common.hpp"
+#include "ggml-sycl/presets.hpp"
 #include "ggml.h"
 #include "element_wise.hpp"
 
+#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
+    for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
+
+#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
+    (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
+
+
 static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+    const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
+    const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
     if (i >= ne) {
         return;
     }
@@ -21,535 +28,375 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne,
     }
 }
 
+/* Unary OP funcs */
 template<typename T>
-static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        dst[i] = x[i] > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x[i] < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
-    }
+static __dpct_inline__ T op_sgn(T x) {
+    return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
 }
 
 template<typename T>
-static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        dst[i] = sycl::fabs(x[i]);
-    }
+static __dpct_inline__ T op_abs(T x) {
+    return sycl::fabs(x);
 }
 
 template<typename T>
-static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-        dst[i] = (x[i] > static_cast<T>(0.f)) ? x[i] : sycl::expm1(x[i]);
-    }
+static __dpct_inline__ T op_elu(T x) {
+    return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
 }
 
 template<typename T>
-static void gelu(const T * x, T * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
+static __dpct_inline__ T op_gelu(T x) {
     const T GELU_COEF_A    = static_cast<T>(0.044715f);
     const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-
-    float xi = x[i];
-    dst[i] = static_cast<T>(0.5f) * xi *
-             (static_cast<T>(1.0f) +
-              sycl::tanh(SQRT_2_OVER_PI * xi * (static_cast<T>(1.0f) + GELU_COEF_A * xi * xi)));
+    return static_cast<T>(0.5f) * x *
+           (static_cast<T>(1.0f) +
+            sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
 }
 
 template<typename T>
-static void silu(const T * x, T * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
+static __dpct_inline__ T op_silu(T x) {
+    return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 }
 
 template<typename T>
-static void gelu_quick(const T *x, T *dst, int k,
-                           const sycl::nd_item<3> &item_ct1) {
-    const float GELU_QUICK_COEF = -1.702f;
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
+static __dpct_inline__ T op_gelu_quick(T x) {
+    const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
+    return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
 }
 
 template<typename T>
-static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+static __dpct_inline__ T op_gelu_erf(T x) {
     const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
-    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
-       auto x_i = x[i];
-        dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
-    }
+    return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
 }
 
 template<typename T>
-static void tanh(const T *x, T *dst, int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::tanh((x[i]));
+static __dpct_inline__ T op_tanh(T x) {
+    return sycl::tanh(x);
 }
 
 template<typename T>
-static void relu(const T * x, T * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((x[i]), static_cast<T>(0));
+static __dpct_inline__ T op_relu(T x) {
+    return sycl::fmax(x, static_cast<T>(0));
 }
 
 template<typename T>
-static void sigmoid(const T * x, T * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = 1.0f / (static_cast<T>(1.0f) + sycl::native::exp(-x[i]));
+static __dpct_inline__ T op_sigmoid(T x) {
+    return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
 }
 
 template<typename T>
-static void sqrt(const T * x, T * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::sqrt(x[i]);
+static __dpct_inline__ T op_sqrt(T x) {
+    return sycl::sqrt(x);
 }
 
 template<typename T>
-static void sin(const T * x, T * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::sin(x[i]);
+static __dpct_inline__ T op_sin(T x) {
+    return sycl::sin(x);
 }
 
 template<typename T>
-static void cos(const T * x, T * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::cos(x[i]);
+static __dpct_inline__ T op_cos(T x) {
+    return sycl::cos(x);
 }
 
 template<typename T>
-static void hardsigmoid(const T * x, T * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
+static __dpct_inline__ T op_hardsigmoid(T x) {
+    return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
-static void hardswish(const T * x, T * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x[i] + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
+static __dpct_inline__ T op_hardswish(T x) {
+    return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 }
 
 template<typename T>
-static void exp(const T * x, T * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_exp(T x) {
+    return sycl::exp(x);
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static __dpct_inline__ T op_log(T x) {
+    if (x <= static_cast<T>(0)) {
+        return neg_infinity<T>();
     }
-    dst[i] = sycl::exp(x[i]);
+    return sycl::log(x);
 }
 
 template<typename T>
-static void log(const T * x, T * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_neg(T x) {
+    return -x;
+}
 
-    if (i >= k) {
-        return;
-    }
-    T xi = x[i];
-    if (xi <= 0) {
-        dst[i] = neg_infinity<T>();
-    } else {
-        dst[i] = sycl::log(xi);
-    }
+template<typename T>
+static __dpct_inline__ T op_step(T x) {
+    return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);
 }
 
 template<typename T>
-static void neg(const T * x, T * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
+    T neg_slope_T = static_cast<T>(negative_slope);
+    return sycl::fmax(x, static_cast<T>(0)) +
+           sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
+}
 
-    if (i >= k) {
-        return;
-    }
-    dst[i] = -x[i];
+template<typename T>
+static __dpct_inline__ T op_sqr(T x) {
+    return x * x;
 }
 
 template<typename T>
-static void step(const T * x, T * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
+    return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
+}
 
-    if (i >= k) {
-        return;
+template<typename T>
+static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sgn(x[i]);
     }
-    dst[i] = x[i] > static_cast<T>(0.0f);
 }
 
 template<typename T>
-static void leaky_relu(const T *x, T *dst, const int k, const float negative_slope,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
+static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_abs(x[i]);
     }
-    dst[i] = sycl::fmax((x[i]), static_cast<T>(0)) +
-             sycl::fmin((x[i]), static_cast<T>(0.0f)) * negative_slope;
 }
 
 template<typename T>
-static void sqr(const T * x, T * dst, const int k,
-                    const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
+static void unary_op_elu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_elu(x[i]);
     }
-    dst[i] = x[i] * x[i];
 }
 
-template<typename  T>
-static void upscale(const T  *x, T *dst, const int nb00, const int nb01,
-                        const int nb02, const int nb03, const int ne10, const int ne11,
-                        const int ne12, const int ne13, const float sf0, const float sf1,
-                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
-    int index = item_ct1.get_local_id(0) +
-               item_ct1.get_group(0) * item_ct1.get_local_range(0);
-    if (index >= ne10 * ne11 * ne12 * ne13) {
-        return;
+template<typename T>
+static void unary_op_gelu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_gelu(x[i]);
     }
-    // operation
-    int i10 = index % ne10;
-    int i11 = (index / ne10) % ne11;
-    int i12 = (index / (ne10 * ne11)) % ne12;
-    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
-
-    int i00 = i10 / sf0;
-    int i01 = i11 / sf1;
-    int i02 = i12 / sf2;
-    int i03 = i13 / sf3;
-
-    dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
 
-template <typename T>
-static void pad(const T  *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
-                    const sycl::nd_item<3> &item_ct1) {
-    int nidx = item_ct1.get_local_id(2) +
-               item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (nidx >= ne0) {
-        return;
+template<typename T>
+static void unary_op_silu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_silu(x[i]);
     }
+}
 
-    // operation
-    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
-                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
-    if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
-        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
-                         item_ct1.get_group(0) * ne00 * ne01;
-            dst[offset_dst] = x[offset_src];
-    } else {
-        dst[offset_dst] = static_cast<T>(0.0f);
+template<typename T>
+static void unary_op_gelu_quick_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_gelu_quick(x[i]);
     }
 }
 
-
 template<typename T>
-static void clamp(const T * x, T * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
+static void unary_op_gelu_erf_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_gelu_erf(x[i]);
     }
-
-    dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
 }
 
-static void acc_f32_sycl(const float *x, const float *y, float *dst,
-                         const int n_elements, const int ne10, const int ne11,
-                         const int ne12, const int nb1, const int nb2,
-                         const int offset, queue_ptr stream) {
-    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, item_ct1);
-                      });
+template<typename T>
+static void unary_op_tanh_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_tanh(x[i]);
+    }
 }
 
 template<typename T>
-static void gelu_sycl(const T *x, T *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { gelu(x, dst, k, item_ct1); });
+static void unary_op_relu_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_relu(x[i]);
+    }
 }
 
 template<typename T>
-static void silu_sycl(const T *x, T *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { silu(x, dst, k, item_ct1); });
+static void unary_op_sigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sigmoid(x[i]);
+    }
 }
 
 template<typename T>
-static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
-    // hard code for now
-    const int num_blocks = ceil_div(k, 256);
-    sycl_parallel_for(
-        stream, sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)),
-        [=](sycl::nd_item<3> item_ct1) { sgn(x, dst, k, item_ct1); });
+static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sqrt(x[i]);
+    }
 }
 
 template<typename T>
-static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
-    // hard code for now
-    const int num_blocks = ceil_div(k, 256);
-    sycl_parallel_for(
-        stream,
-        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
-        [=](sycl::nd_item<3> item_ct1) { abs_op(x, dst, k, item_ct1); });
+static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sin(x[i]);
+    }
 }
 
-
 template<typename T>
-static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) {
-    // hard code for now
-    const int num_blocks = ceil_div(k, 256);
-    sycl_parallel_for(
-        stream,
-        sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)),
-        [=](sycl::nd_item<3> item_ct1) { elu_op(x, dst, k, item_ct1); });
+static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_cos(x[i]);
+    }
 }
 
 template<typename T>
-static void gelu_quick_sycl(const T *x, T *dst, const int k,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { gelu_quick(x, dst, k, item_ct1); });
+static void unary_op_hardsigmoid_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_hardsigmoid(x[i]);
+    }
 }
 
-
 template<typename T>
-static void gelu_erf_sycl(const T *x, T *dst, const int k,
-                                queue_ptr stream) {
-    const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { gelu_erf(x, dst, k, item_ct1); });
+static void unary_op_hardswish_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_hardswish(x[i]);
+    }
 }
 
 template<typename T>
-static void tanh_sycl(const T *x, T *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { tanh(x, dst, k, item_ct1); });
+static void unary_op_exp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_exp(x[i]);
+    }
 }
 
 template<typename T>
-static void relu_sycl(const T *x, T *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { relu(x, dst, k, item_ct1); });
+static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_log(x[i]);
+    }
 }
 
 template<typename T>
-static void hardsigmoid_sycl(const T *x, T *dst, const int k,
-                                 queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
-    sycl_parallel_for(
-        stream,
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) { hardsigmoid(x, dst, k, item_ct1); });
+static void unary_op_neg_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_neg(x[i]);
+    }
 }
 
 template<typename T>
-static void hardswish_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
-    sycl_parallel_for(
-        stream,
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) { hardswish(x, dst, k, item_ct1); });
+static void unary_op_step_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_step(x[i]);
+    }
 }
 
 template<typename T>
-static void exp_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { exp(x, dst, k, item_ct1); });
+static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_leaky_relu(x[i], negative_slope);
+    }
 }
 
 template<typename T>
-static void log_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { log(x, dst, k, item_ct1); });
+static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_sqr(x[i]);
+    }
 }
 
 template<typename T>
-static void neg_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { neg(x, dst, k, item_ct1); });
+static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = op_clamp(x[i], min_val, max_val);
+    }
 }
 
-template<typename T>
-static void step_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { step(x, dst, k, item_ct1); });
+template<typename  T>
+static void upscale(const T  *x, T *dst, const int nb00, const int nb01,
+                        const int nb02, const int nb03, const int ne10, const int ne11,
+                        const int ne12, const int ne13, const float sf0, const float sf1,
+                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
+    int index = item_ct1.get_local_id(0) +
+               item_ct1.get_group(0) * item_ct1.get_local_range(0);
+    if (index >= ne10 * ne11 * ne12 * ne13) {
+        return;
+    }
+    // operation
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
+    int i00 = static_cast<int>(i10 / sf0);
+    int i01 = static_cast<int>(i11 / sf1);
+    int i02 = static_cast<int>(i12 / sf2);
+    int i03 = static_cast<int>(i13 / sf3);
+
+    dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
 
-template<typename T>
-static void sigmoid_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE;
-    sycl_parallel_for(
-        stream,
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) { sigmoid(x, dst, k, item_ct1); });
+template <typename T>
+static void pad(const T  *x, T *dst, const int ne0, const int ne00, const int ne01, const int ne02,
+                    const sycl::nd_item<3> &item_ct1) {
+    int nidx = SYCL_LOCAL_ID_CALC(item_ct1, 2);
+    if (nidx >= ne0) {
+        return;
+    }
+
+    // operation
+    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+    if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) {
+        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
+                         item_ct1.get_group(0) * ne00 * ne01;
+            dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = static_cast<T>(0.0f);
+    }
 }
 
 template<typename T>
-static void sqrt_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { sqrt(x, dst, k, item_ct1); });
+static void clamp(const T * x, T * dst, const float min, const float max, const int k,
+                      const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
+    }
 }
 
 template<typename T>
-static void sin_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { sin(x, dst, k, item_ct1); });
+static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu(x[j0]) * g[j1];
+    }
 }
 
 template<typename T>
-static void cos_sycl(const T *x, T *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { cos(x, dst, k, item_ct1); });
+static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_relu(x[j0]) * g[j1];
+    }
 }
 
 template<typename T>
-static void leaky_relu_sycl(const T *x, T *dst, const int k,
-                                const float negative_slope,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { leaky_relu(x, dst, k, negative_slope, item_ct1); });
+static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1)  {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_silu(x[j0]) * g[j1];
+    }
 }
 
-template<typename T>
-static void sqr_sycl(const T *x, T *dst, const int k,
-                         queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
+namespace ggml_sycl_detail {
+static void acc_f32_sycl(const float *x, const float *y, float *dst,
+                         const int n_elements, const int ne10, const int ne11,
+                         const int ne12, const int nb1, const int nb2,
+                         const int offset, queue_ptr stream) {
+    int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
     sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { sqr(x, dst, k, item_ct1); });
+        sycl::nd_range<1>(sycl::range<1>(num_blocks) *
+                              sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
+                          sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+                    item_ct1);
+        });
 }
 
 template<typename T>
@@ -558,7 +405,7 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
                              const int ne12, const int ne13, const float sf0, const float sf1,
                              const float sf2, const float sf3, queue_ptr stream) {
     int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+    int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
     sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
     sycl_parallel_for<1>(
         stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
@@ -570,7 +417,7 @@ template<typename T>
 static void pad_sycl(const T *x, T *dst, const int ne00,
                          const int ne01, const int ne02, const int ne0,
                          const int ne1, const int ne2, queue_ptr stream) {
-    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+    int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
     sycl::range<3> gridDim(ne2, ne1, num_blocks);
     sycl_parallel_for(stream,
                       sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
@@ -578,22 +425,11 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
                       [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
 }
 
-template<typename T>
-static void clamp_sycl(const T *x, T *dst, const float min,
-                           const float max, const int k,
-                           queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
-}
-
-inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
 #else
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -606,14 +442,14 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
-                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
-                sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
                 break;
             }
         default:
@@ -621,11 +457,11 @@ inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     }
 }
 
-inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
 #else
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
@@ -633,52 +469,66 @@ inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     GGML_ASSERT(dst->src[0]->type == dst->type);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;;
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
     }
-}
-
-
-inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     switch (dst->type) {
 #if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
-                auto data_pts = cast_data<sycl::half>(dst);
-                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                sycl::half * src0_p = (sycl::half *) src0_d;
+                sycl::half * src1_p = (sycl::half *) src1_d;
+
+                    if (!src1) {
+                        src0_p += swapped ? nc : 0;
+                        src1_p += swapped ? 0 : nc;
+                    }
+                kernel_invoker(src0_p,
+                               src1_p,
+                               (sycl::half *) dst_d,
+                               ggml_nelements(dst),
+                               nc,
+                               src0_o / sizeof(sycl::half),
+                               src1_o / sizeof(sycl::half),
+                               main_stream,
+                               std::forward<Args>(args)...);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
-                auto data_pts = cast_data<float>(dst);
-                elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                float * src0_p = (float *) src0_d;
+                float * src1_p = (float *) src1_d;
+
+                    if (!src1) {
+                        src0_p += swapped ? nc : 0;
+                        src1_p += swapped ? 0 : nc;
+                    }
+
+                kernel_invoker(src0_p,
+                               src1_p,
+                               (float *) dst_d,
+                               ggml_nelements(dst),
+                               nc,
+                               src0_o / sizeof(float),
+                               src1_o / sizeof(float),
+                               main_stream,
+                               std::forward<Args>(args)...);
                 break;
             }
         default:
@@ -686,7 +536,8 @@ inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     }
 }
 
-inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -695,52 +546,31 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 #endif
     GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                silu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
 
-inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
+    const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
+    const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
+    const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
     switch (dst->type) {
 #if defined (GGML_SYCL_F16)
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
-                gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+                               (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+                               main_stream, std::forward<Args>(args)...);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
-                gelu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
+                               (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
+                               main_stream, std::forward<Args>(args)...);
                 break;
             }
         default:
@@ -748,7 +578,8 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst
     }
 }
 
-inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+template<typename KernelInvoker, typename... Args>
+static inline void dispatch_ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
@@ -757,6 +588,7 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
 #endif
     GGML_ASSERT(dst->src[0]->type == dst->type);
+    GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
     dpct::queue_ptr main_stream = ctx.stream();
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     switch (dst->type) {
@@ -764,14 +596,16 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
         case GGML_TYPE_F16:
             {
                 auto data_pts = cast_data<sycl::half>(dst);
-                gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
+                               (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
                 break;
             }
 #endif
         case GGML_TYPE_F32:
             {
                 auto data_pts = cast_data<float>(dst);
-                gelu_quick_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->ne[0], (int)dst->src[0]->ne[1], (int)dst->src[0]->ne[2], (int)dst->ne[0],
+                               (int)dst->ne[1], (int)dst->ne[2], main_stream, std::forward<Args>(args)...);
                 break;
             }
         default:
@@ -779,593 +613,320 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
-inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
-
-
-inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                tanh_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
-
-inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
-}
+} // namespace ggml_sycl_detail
 
-inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
 
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                hardsigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sgn_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                hardswish_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_abs_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                exp_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, 256);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
+                                  sycl::range<1>(256)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_elu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                log_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_silu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sigmoid_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_gelu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sqrt_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_gelu_quick_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
-
-inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sin_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+
+static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_gelu_erf_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                cos_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_tanh_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                step_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_relu_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                neg_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_hardsigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
+static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_hardswish_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
 
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                leaky_relu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), negative_slope, main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_exp_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- #if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                sqr_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
+static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_neg_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
 
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_step_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
 
-    const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
-    const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
-    const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
-    const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
-                        dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                        main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                upscale_sycl(data_pts.src, data_pts.dst, dst->src[0]->nb[0], dst->src[0]->nb[1], dst->src[0]->nb[2],
-                        dst->src[0]->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                        main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sigmoid_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined (GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    GGML_ASSERT(dst->src[0]->ne[3] == 1 && dst->ne[3] == 1);  // just 3D tensors
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    switch (dst->type) {
-#if defined (GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
-                        dst->ne[1], dst->ne[2], main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                pad_sycl(data_pts.src, data_pts.dst, dst->src[0]->ne[0], dst->src[0]->ne[1], dst->src[0]->ne[2], dst->ne[0],
-                        dst->ne[1], dst->ne[2], main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
 }
 
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-#if defined(GGML_SYCL_F16)
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
-#else
+static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
 
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-#endif
-    GGML_ASSERT(dst->src[0]->type == dst->type);
-    dpct::queue_ptr main_stream = ctx.stream();
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
 
-    switch (dst->type) {
-#if defined(GGML_SYCL_F16)
-        case GGML_TYPE_F16:
-            {
-                auto data_pts = cast_data<sycl::half>(dst);
-                clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-#endif
-        case GGML_TYPE_F32:
-            {
-                auto data_pts = cast_data<float>(dst);
-                clamp_sycl(data_pts.src, data_pts.dst, min, max, ggml_nelements(dst->src[0]), main_stream);
-                break;
-            }
-        default:
-            GGML_ABORT("GGML tensor type not supported!\n");
-    }
+static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
+            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
+                });
+        }, negative_slope);
+}
+
+static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
+            const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
+                });
+        });
+}
+
+static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
+           int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
+           queue_ptr stream) {
+            ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
+        });
 }
 
-inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_pad(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2,
+           queue_ptr stream) {
+            ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+        });
+}
 
+static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    float min_val;
+    float max_val;
+    memcpy(&min_val, dst->op_params, sizeof(float));
+    memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
+    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
+        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
+            const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
+            sycl_parallel_for(stream,
+                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
+                                  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
+                [=](sycl::nd_item<1> item_ct1) {
+                    clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
+                });
+        }, min_val, max_val);
+}
+
+static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -1381,7 +942,40 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
     // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
     int offset = dst->op_params[3] / 4; // offset in bytes
 
-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+    ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
+}
+
+static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+            });
+        });
+}
+
+static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
+            sycl_parallel_for(main_stream,
+                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+            });
+        });
+}
+
+static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
+            sycl_parallel_for(main_stream,
+                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+            });
+        });
 }
 
 
@@ -1509,3 +1103,18 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_elu(ctx, dst);
 }
+
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu(ctx, dst);
+}
+
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_reglu(ctx, dst);
+}
+
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_swiglu(ctx, dst);
+}
index bd40113f0970560317caa4d4fc4bc506c5485495..86068b10129ec5406643b23057e5c87c24ab1e70 100644 (file)
@@ -3,27 +3,30 @@
 
 #include "common.hpp"
 #include "ggml.h"
-#include <limits.h>
+#include <limits> // For std::numeric_limits
 
 template <typename T>
 T neg_infinity() {
     return -std::numeric_limits<T>::infinity();
 }
 
-template<typename T>
+template<typename T_Dst, typename T_Src = T_Dst>
 struct typed_data {
-    const T * src;
-    T * dst;
+    const T_Src * src;
+    T_Dst * dst;
 };
 
-template<typename T>
-typed_data<T> cast_data(ggml_tensor * dst) {
+template<typename T_Dst, typename T_Src = T_Dst>
+typed_data<T_Dst, T_Src> cast_data(ggml_tensor * dst) {
     return {
-        /* .src = */ static_cast<const T *>(dst->src[0]->data),
-        /* .dst = */ static_cast<T *>(dst->data)
+        /* .src = */ static_cast<const T_Src *>(dst->src[0]->data),
+        /* .dst = */ static_cast<T_Dst *>(dst->data)
     };
 }
 
+const float GELU_QUICK_COEF = -1.702f;
+
+
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -73,5 +76,9 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-#endif // GGML_SYCL_ELEMENTWISE_HPP
 
+void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#endif // GGML_SYCL_ELEMENTWISE_HPP
index 9cb36ae99e7f57ff4706037e524df7f21e2bfe2f..ae5e062572e324a8c54eab4abe482742d0d6fc0c 100644 (file)
@@ -3676,6 +3676,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    ggml_sycl_reglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                    ggml_sycl_geglu(ctx, dst);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    ggml_sycl_swiglu(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_sycl_norm(ctx, dst);
             break;
@@ -4212,6 +4227,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
index aebcc03915f5f71145d66fe159902485ef3df8d7..4696f1fe46e3bb2531e0afb810120d75d306e828 100644 (file)
@@ -437,6 +437,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
 
+    vk_pipeline pipeline_geglu[2];
+    vk_pipeline pipeline_reglu[2];
+    vk_pipeline pipeline_swiglu[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
@@ -661,6 +665,13 @@ struct vk_op_push_constants {
     float param2;
 };
 
+struct vk_op_glu_push_constants {
+    uint32_t N;
+    uint32_t ne00;
+    uint32_t ne20;
+    uint32_t mode;  // 0: default, 1: swapped, 2: split
+};
+
 struct vk_op_unary_push_constants {
     uint32_t ne;
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2757,6 +2768,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(sigmoid)
 #undef CREATE_UNARY
 
+#define CREATE_GLU(name)  \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);  \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
+
+    CREATE_GLU(geglu)
+    CREATE_GLU(reglu)
+    CREATE_GLU(swiglu)
+#undef CREATE_GLU
+
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
@@ -6473,6 +6493,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 break;
         }
         return nullptr;
+    case GGML_OP_GLU:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
+        switch (ggml_get_glu_op(dst)) {
+            case GGML_GLU_OP_GEGLU:
+                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_REGLU:
+                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU:
+                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            default:
+                break;
+        }
+        return nullptr;
     case GGML_OP_DIAG_MASK_INF:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_diag_mask_inf_f32;
@@ -6933,6 +6971,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_UNARY:
+    case GGML_OP_GLU:
     case GGML_OP_CONV_2D_DW:
         {
             uint32_t ne = ggml_nelements(dst);
@@ -6973,7 +7012,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_SOFT_MAX) {
+    if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
         // Empty src1 is possible in soft_max, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
@@ -7566,6 +7605,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const bool swapped = (bool)dst->op_params[1];
+    const bool split = src1 != nullptr;
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    if (!split) {
+        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+    } else {
+        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
+
+    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+}
+
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     int32_t * op_params = (int32_t *)dst->op_params;
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
@@ -8778,6 +8836,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+        case GGML_GLU_OP_GEGLU:
+        case GGML_GLU_OP_REGLU:
+        case GGML_GLU_OP_SWIGLU:
+            break;
+        default:
+            return false;
+        }
+        break;
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
@@ -8870,6 +8938,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_L2_NORM:
         case GGML_OP_UNARY:
+        case GGML_OP_GLU:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
@@ -9013,6 +9082,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(node)) {
+        case GGML_GLU_OP_GEGLU:
+        case GGML_GLU_OP_REGLU:
+        case GGML_GLU_OP_SWIGLU:
+            ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
+            break;
+        default:
+            return false;
+        }
+        break;
     case GGML_OP_DIAG_MASK_INF:
         ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
 
@@ -9138,8 +9218,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         if (!ok) {
             if (node->op == GGML_OP_UNARY) {
                 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-            }
-            else {
+            } else if (node->op == GGML_OP_GLU) {
+                std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
+            } else {
                 std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
             }
         }
@@ -9218,6 +9299,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             return false;
         }
         break;
+    case GGML_OP_GLU:
+        switch (ggml_get_glu_op(tensor)) {
+        case GGML_GLU_OP_GEGLU:
+        case GGML_GLU_OP_REGLU:
+        case GGML_GLU_OP_SWIGLU:
+            buf = tensor->buffer;
+            break;
+        default:
+            return false;
+        }
+        break;
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
     case GGML_OP_FLASH_ATTN_EXT:
@@ -10016,6 +10108,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous(op->src[0]) &&
+                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                           (op->src[0]->type == op->type);
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
@@ -10746,6 +10851,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
             std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
             GGML_ABORT("fatal error");
         }
+    } else if (tensor->op == GGML_OP_GLU) {
+        if (src_clone[1] == nullptr) {
+            tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+        } else {
+            tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
+        }
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
new file mode 100644 (file)
index 0000000..f4268ed
--- /dev/null
@@ -0,0 +1,13 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_COEF_A    = 0.044715f;
+const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+float op(float a, float b) {
+    const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
+    return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
+}
+
+#include "glu_main.comp"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp
new file mode 100644 (file)
index 0000000..41a2988
--- /dev/null
@@ -0,0 +1,15 @@
+#extension GL_EXT_shader_16bit_storage : require
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (push_constant) uniform parameter
+{
+    uint N;
+    uint ne00;
+    uint ne20;
+    uint mode;
+} p;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp
new file mode 100644 (file)
index 0000000..85cf65a
--- /dev/null
@@ -0,0 +1,29 @@
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.N) {
+        return;
+    }
+
+    const uint row = i / p.ne20;
+    const uint col = i - row * p.ne20;
+
+    if (p.mode == 0) {
+        // Default
+        const uint offset = p.ne00 / 2;
+        const uint idx = row * p.ne00 + col;
+
+        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+    } else if (p.mode == 1) {
+        // Swapped
+        const uint offset = p.ne00 / 2;
+        const uint idx = row * p.ne00 + col;
+
+        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+    } else {
+        // Split
+        const uint idx = row * p.ne00 + col;
+
+        data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+    }
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
new file mode 100644 (file)
index 0000000..0073d8f
--- /dev/null
@@ -0,0 +1,9 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+    return max(a, 0.0f) * b;
+}
+
+#include "glu_main.comp"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
new file mode 100644 (file)
index 0000000..a28e7c6
--- /dev/null
@@ -0,0 +1,9 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+    return a / (1.0f + exp(-a)) * b;
+}
+
+#include "glu_main.comp"
index a207b98c60e510ef55c1a448bf4fe888bb03e90c..23fc50bf29503b27b58cc9939b0484703093afde 100644 (file)
@@ -585,6 +585,13 @@ void process_shaders() {
     string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("sigmoid_f32",    "sigmoid.comp",     {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
+    string_to_spv("geglu_f16",      "geglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_f32",      "geglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("reglu_f16",      "reglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("reglu_f32",      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("swiglu_f16",     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("swiglu_f32",     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+
     string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
index e70d9d57b3bbea428a39f3a7afc85579e782cd8a..fa57b60ef7d4efa3dcd62dd364659b7690dd5c41 100644 (file)
@@ -982,9 +982,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1079,9 +1081,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
+static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1107,6 +1111,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
 static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
 
 
+static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
+    "REGLU",
+    "GEGLU",
+    "SWIGLU",
+};
+
+static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+
+
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
 
@@ -1209,11 +1222,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
     return GGML_UNARY_OP_NAME[op];
 }
 
+const char * ggml_glu_op_name(enum ggml_glu_op op) {
+    return GGML_GLU_OP_NAME[op];
+}
+
 const char * ggml_op_desc(const struct ggml_tensor * t) {
     if (t->op == GGML_OP_UNARY) {
         enum ggml_unary_op uop = ggml_get_unary_op(t);
         return ggml_unary_op_name(uop);
     }
+    if (t->op == GGML_OP_GLU) {
+        enum ggml_glu_op gop = ggml_get_glu_op(t);
+        return ggml_glu_op_name(gop);
+    }
     return ggml_op_name(t->op);
 }
 
@@ -1730,6 +1751,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
     return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
 }
 
+enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_GLU);
+    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
+}
+
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -2609,6 +2635,114 @@ struct ggml_tensor * ggml_exp_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
 }
 
+// ggml_glu
+
+static struct ggml_tensor * ggml_glu_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_glu_op      op,
+        bool                  swapped) {
+    GGML_ASSERT(ggml_is_contiguous_1(a));
+
+    if (b) {
+        GGML_ASSERT(ggml_is_contiguous_1(b));
+        GGML_ASSERT(ggml_are_same_shape(a, b));
+        GGML_ASSERT(a->type == b->type);
+    }
+
+    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
+    ggml_set_op_params_i32(result, 1, (int32_t) swapped);
+
+    result->op     = GGML_OP_GLU;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_glu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        enum ggml_glu_op      op,
+        bool                  swapped) {
+    return ggml_glu_impl(ctx, a, NULL, op, swapped);
+}
+
+struct ggml_tensor * ggml_glu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        enum ggml_glu_op      op) {
+    return ggml_glu_impl(ctx, a, b, op, false);
+}
+
+// ggml_reglu
+
+struct ggml_tensor * ggml_reglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
+}
+
+struct ggml_tensor * ggml_reglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
+}
+
+struct ggml_tensor * ggml_reglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
+}
+
+// ggml_geglu
+
+struct ggml_tensor * ggml_geglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
+}
+
+struct ggml_tensor * ggml_geglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
+}
+
+struct ggml_tensor * ggml_geglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
+}
+
+// ggml_swiglu
+
+struct ggml_tensor * ggml_swiglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
+}
+
+struct ggml_tensor * ggml_swiglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
+}
+
+struct ggml_tensor * ggml_swiglu_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(