]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : implement GEGLU_ERF and GEGLU_QUICK ops (llama/14445)
authorSigbjørn Skjæret <redacted>
Thu, 3 Jul 2025 21:07:22 +0000 (23:07 +0200)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
19 files changed:
include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/vec.h
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/glu.cl
src/ggml-sycl/element_wise.cpp
src/ggml-sycl/element_wise.hpp
src/ggml-sycl/ggml-sycl.cpp
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/geglu_erf.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/geglu_quick.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
src/ggml.c

index d25bfcc9f45f695a7df36a9595dfbdb0dd70c91e..74068f2be6061bb19ddef3cee8c6497f1395ceeb 100644 (file)
@@ -557,6 +557,8 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_GEGLU_ERF,
+        GGML_GLU_OP_GEGLU_QUICK,
 
         GGML_GLU_OP_COUNT,
     };
@@ -1144,6 +1146,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // A: n columns, r rows,
     // B: n columns, r rows,
     GGML_API struct ggml_tensor * ggml_glu_split(
@@ -1167,6 +1185,16 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_geglu_erf_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_geglu_quick_split(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
index 11ff228f07a445d12b3fafea5b8f1d41f18814ac..c5271b77572289a9dd670aac7014112882de9409 100644 (file)
@@ -2172,6 +2172,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     {
                         n_tasks = n_threads;
                     } break;
index 0fb2c08b5bcb5ffc15676a42ad6c95a354bdaca1..aaeee614ab993b33438aa2291c7889fc55417bac 100644 (file)
@@ -3614,6 +3614,292 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_geglu_erf
+
+static void ggml_compute_forward_geglu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_geglu_quick
+
+static void ggml_compute_forward_geglu_quick_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
+        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_geglu_quick(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_geglu_quick_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_geglu_quick_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_norm
 
 static void ggml_compute_forward_norm_f32(
@@ -8779,6 +9065,14 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            {
+                ggml_compute_forward_geglu_erf(params, dst);
+            } break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            {
+                ggml_compute_forward_geglu_quick(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
index c432c990818226c3d3cd357cab1453c6cdf7e737..1f5857a23e35c06aac9934c25c09e26047034332 100644 (file)
@@ -959,6 +959,46 @@ inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_
     }
 }
 
+inline static void ggml_vec_geglu_erf_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * g[i];
+    }
+}
+
+inline static void ggml_vec_geglu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(0.5f * xi * (1.0f + erff(xi*SQRT_2_INV)) * gi);
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_CPU_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]) * g[i];
+    }
+}
+#else
+inline static void ggml_vec_geglu_quick_f32(const int n, float * y, const float * x, const float * g) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]) * g[i];
+    }
+}
+#endif
+
+inline static void ggml_vec_geglu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        float v = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(ggml_table_gelu_quick_f16[i16[i]]) * v);
+    }
+}
+
 inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
index 1c04bba52e88bf2419a3d10854086147c88cf06c..af5ad1ed52cdcbe114ed6792a313dca043c3bbca 100644 (file)
@@ -2314,6 +2314,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    ggml_cuda_op_geglu_erf(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    ggml_cuda_op_geglu_quick(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -3116,6 +3122,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
index ba3c0f13762b0abb8a62f8bcdc2b58ef83278454..f9c7b83c40d1bb83786ff04be0c1299655628f28 100644 (file)
@@ -285,6 +285,14 @@ void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
 }
 
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
index 9094f1d0bad3740a18618caee851cecc911fe364..289d690e5cff6c0b63fed83e5da42bd846a52104 100644 (file)
@@ -64,3 +64,7 @@ void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 5e5467e88a1ab20122b7bc0395e4a0f6a6b850e3..40fc315e82fd15d96019ee7cb15edb8fc180ffba 100644 (file)
@@ -530,6 +530,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
+    GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1510,6 +1512,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,                           reglu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,                           geglu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,                          swiglu,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
@@ -1693,6 +1697,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                default:
                     return false;
@@ -2456,6 +2462,12 @@ static bool ggml_metal_encode_node(
                     case GGML_GLU_OP_SWIGLU:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                         break;
+                    case GGML_GLU_OP_GEGLU_ERF:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
+                        break;
+                    case GGML_GLU_OP_GEGLU_QUICK:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_QUICK].pipeline;
+                        break;
                     default:
                         GGML_ABORT("fatal error");
                 }
index ebde005f33e028013fa737abfdb99cfbe9d4256a..dc7a0af2769dccfe158f01db9f210129b7165c95 100644 (file)
@@ -1258,6 +1258,50 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_geglu_erf(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_quick(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
index 2450100b43c95120273c92bf9d69edc2d64c34b1..970dd3f67f5f3496ad5b47a615932c7f72b82b88 100644 (file)
@@ -402,8 +402,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
     cl_kernel kernel_clamp;
-    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu,
-              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16;
+    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
+              kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
     cl_kernel kernel_rms_norm;
     cl_kernel kernel_group_norm;
@@ -753,12 +753,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_glu =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_geglu      = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu      = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu     = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
-        CL_CHECK((backend_ctx->kernel_geglu_f16  = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_reglu_f16  = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
-        CL_CHECK((backend_ctx->kernel_swiglu_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick     = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_reglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_reglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu_f16      = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_erf_f16   = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_geglu_quick_f16 = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick_f16", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2277,6 +2281,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
                 default:
                     return false;
@@ -6254,6 +6260,20 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
                 kernel = backend_ctx->kernel_swiglu_f16;
             }
             break;
+        case GGML_GLU_OP_GEGLU_ERF:
+            if (dst->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_geglu_erf;
+            } else {
+                kernel = backend_ctx->kernel_geglu_erf_f16;
+            }
+            break;
+        case GGML_GLU_OP_GEGLU_QUICK:
+            if (dst->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_geglu_quick;
+            } else {
+                kernel = backend_ctx->kernel_geglu_quick_f16;
+            }
+            break;
         default:
             GGML_ABORT("Unsupported glu op");
     }
index ba861d8b18f0cbf06213ab5f2d751c355b8940fb..7cca16e6a9e7e5f271c2f328fa1fc9d730bc163a 100644 (file)
@@ -1,7 +1,9 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
 #define GELU_COEF_A     0.044715f
+#define GELU_QUICK_COEF -1.702f
 #define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
+#define SQRT_2_INV      0.70710678118654752440084436210484f
 
 //------------------------------------------------------------------------------
 // geglu
@@ -199,3 +201,137 @@ kernel void kernel_swiglu_f16(
         dst_row[i0] = silu*x1;
     }
 }
+
+//------------------------------------------------------------------------------
+// geglu_erf
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_erf(
+    global char * src0,
+    ulong  offset0,
+    global char * src1,
+    ulong  offset1,
+    global char * dst,
+    ulong  offsetd,
+    ulong nb01,
+    ulong nb11,
+    int ne0,
+    ulong nb1,
+    int ne00_off,
+    int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+kernel void kernel_geglu_erf_f16(
+    global char * src0,
+    ulong  offset0,
+    global char * src1,
+    ulong  offset1,
+    global char * dst,
+    ulong  offsetd,
+    ulong nb01,
+    ulong nb11,
+    int ne0,
+    ulong nb1,
+    int ne00_off,
+    int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_erf = 0.5f*x0*(1.0f + erf(x0*SQRT_2_INV));
+
+        dst_row[i0] = gelu_erf*x1;
+    }
+}
+
+//------------------------------------------------------------------------------
+// geglu_quick
+//------------------------------------------------------------------------------
+kernel void kernel_geglu_quick(
+    global char * src0,
+    ulong  offset0,
+    global char * src1,
+    ulong  offset1,
+    global char * dst,
+    ulong  offsetd,
+    ulong nb01,
+    ulong nb11,
+    int ne0,
+    ulong nb1,
+    int ne00_off,
+    int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const float x0 = src0_row[i0];
+        const float x1 = src1_row[i0];
+
+        const float gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
+
+kernel void kernel_geglu_quick_f16(
+    global char * src0,
+    ulong  offset0,
+    global char * src1,
+    ulong  offset1,
+    global char * dst,
+    ulong  offsetd,
+    ulong nb01,
+    ulong nb11,
+    int ne0,
+    ulong nb1,
+    int ne00_off,
+    int ne10_off
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global half * src0_row = (global half *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global half * src1_row = (global half *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global half * dst_row  = (global half *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const half x0 = src0_row[i0];
+        const half x1 = src1_row[i0];
+
+        const half gelu_quick = x0*(1.0f/(1.0f + exp(GELU_QUICK_COEF*x0)));
+
+        dst_row[i0] = gelu_quick*x1;
+    }
+}
index c7788bdb6bf8c4618fade66872175621625d105d..0363b06a3ec9bc3bc1476ce856161e6beeb96df6 100644 (file)
@@ -383,6 +383,24 @@ static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint6
     }
 }
 
+template<typename T>
+static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu_erf(x[j0]) * g[j1];
+    }
+}
+
+template<typename T>
+static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
+    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
+        const int64_t j0 = (i / n) * o0 + (i % n);
+        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+        dst[i] = op_gelu_quick(x[j0]) * g[j1];
+    }
+}
+
 namespace ggml_sycl_detail {
 static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int n_elements, const int ne10, const int ne11,
@@ -978,6 +996,28 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
         });
 }
 
+static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+            });
+        });
+}
+
+static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
+        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
+            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+            sycl_parallel_for(main_stream,
+                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+                gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
+            });
+        });
+}
+
 
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1118,3 +1158,13 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_swiglu(ctx, dst);
 }
+
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_erf(ctx, dst);
+}
+
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_geglu_quick(ctx, dst);
+}
index 86068b10129ec5406643b23057e5c87c24ab1e70..50749e87d783e140463a9ceeed97b9286751431b 100644 (file)
@@ -80,5 +80,7 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 #endif // GGML_SYCL_ELEMENTWISE_HPP
index b063a698ded2bd2d74857521bb262d8433a336a6..21c81e99a19aa560ef4b5c7f2b87224c2d556225 100644 (file)
@@ -3687,6 +3687,12 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_sycl_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_GEGLU_ERF:
+                    ggml_sycl_geglu_erf(ctx, dst);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    ggml_sycl_geglu_quick(ctx, dst);
+                    break;
                 default:
                     return false;
             }
@@ -4232,6 +4238,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;
index c0032cba218cfbc414db0740cc03d52195ed6012..22a34a433568f6a2e6a69612a5e553f224f8f28e 100644 (file)
@@ -456,6 +456,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_geglu_erf[2];
+    vk_pipeline pipeline_geglu_quick[2];
 
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_silu_back_f32;
@@ -2821,6 +2823,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(geglu_erf)
+    CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
 
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -6575,6 +6579,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_SWIGLU:
                 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_GEGLU_ERF:
+                return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_GEGLU_QUICK:
+                return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
             default:
                 break;
         }
@@ -8919,6 +8927,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             break;
         default:
             return false;
@@ -9166,6 +9176,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
             break;
         default:
@@ -9384,6 +9396,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_GEGLU_ERF:
+        case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
             break;
         default:
@@ -10194,6 +10208,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous(op->src[0]) &&
                            (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
diff --git a/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
new file mode 100644 (file)
index 0000000..cbd4cb3
--- /dev/null
@@ -0,0 +1,27 @@
+#version 450
+
+#include "glu_head.comp"
+
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+const float p_erf  = 0.3275911f;
+const float a1_erf = 0.254829592f;
+const float a2_erf = -0.284496736f;
+const float a3_erf = 1.421413741f;
+const float a4_erf = -1.453152027f;
+const float a5_erf = 1.061405429f;
+
+const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+float op(float a, float b) {
+    const float a_div_sqr2 = a * SQRT_2_INV;
+    const float sign_x = sign(a_div_sqr2);
+    const float x = abs(a_div_sqr2);
+    const float t = 1.0f / (1.0f + p_erf * x);
+    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    const float erf_approx = sign_x * y;
+
+    return 0.5f * a * (1.0f + erf_approx) * b;
+}
+
+#include "glu_main.comp"
diff --git a/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
new file mode 100644 (file)
index 0000000..3a2a689
--- /dev/null
@@ -0,0 +1,11 @@
+#version 450
+
+#include "glu_head.comp"
+
+const float GELU_QUICK_COEF = -1.702f;
+
+float op(float a, float b) {
+    return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
+}
+
+#include "glu_main.comp"
index 297a2a77119ea4f815b45aeb6844543ff7334c0a..2698522ed7101fd610459fa8690b0c8ab2b918ed 100644 (file)
@@ -593,6 +593,10 @@ void process_shaders() {
     string_to_spv("reglu_f32",      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("swiglu_f16",     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("swiglu_f32",     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("geglu_erf_f16",  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_erf_f32",  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
index 68768842904602657ffe8d8158805d54b8f35a71..e2d9d616a596cc86bfa1adbc00ad9e46465d792f 100644 (file)
@@ -1132,9 +1132,11 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "GEGLU_ERF",
+    "GEGLU_QUICK",
 };
 
-static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
+static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2760,6 +2762,48 @@ struct ggml_tensor * ggml_swiglu_split(
     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
 }
 
+// ggml_geglu_erf
+
+struct ggml_tensor * ggml_geglu_erf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+struct ggml_tensor * ggml_geglu_erf_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_ERF, true);
+}
+
+struct ggml_tensor * ggml_geglu_erf_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_ERF, false);
+}
+
+// ggml_geglu_quick
+
+struct ggml_tensor * ggml_geglu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
+struct ggml_tensor * ggml_geglu_quick_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU_QUICK, true);
+}
+
+struct ggml_tensor * ggml_geglu_quick_split(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(