]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add ggml_gelu_erf() (llama/13667)
authorXuan-Son Nguyen <redacted>
Wed, 21 May 2025 14:26:33 +0000 (16:26 +0200)
committerGeorgi Gerganov <redacted>
Tue, 27 May 2025 15:03:00 +0000 (18:03 +0300)
* ggml : add ggml_gelu_na (not approximated)

* fix naming order

* rename na --> erf

* apply review suggesions

* revert naming order

ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/vec.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml.c

index e91dedf14a1cbbcf7d554e4b12d95604ec8f168c..c81ff03fee810b0c6108b70a9996411dd21b5c4c 100644 (file)
@@ -528,14 +528,15 @@ extern "C" {
         GGML_UNARY_OP_STEP,
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
-        GGML_UNARY_OP_RELU,
         GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_ERF,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
+        GGML_UNARY_OP_RELU,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1024,6 +1025,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // GELU using erf (error function) when possible
+    // some backends may fallback to approximation based on Abramowitz and Stegun formula
+    GGML_API struct ggml_tensor * ggml_gelu_erf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_gelu_quick(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
index 133b50606bcd1dd9779b7524ce78924dbee9548b..46f75ad97cd6161fe71aa59b79e91ef6ce0d8784 100644 (file)
@@ -2202,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     } break;
 
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                     {
index 955fec59a6e93e9a0d793c6a2bcddd94b7c00a64..26501b7118b95c40eaa5128fa569b21cc133903c 100644 (file)
@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
     }
 }
 
+// ggml_compute_forward_gelu_erf
+
+static void ggml_compute_forward_gelu_erf_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_erf_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_erf_f16(
+    const ggml_compute_params * params,
+    ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(dst));
+    assert(ggml_are_same_shape(src0, dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_erf_f16(nc,
+                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float v = GGML_FP16_TO_FP32(x);
+            GGML_UNUSED(v);
+            assert(!isnan(v));
+            assert(!isinf(v));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_erf(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_erf_f32(params, dst);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_gelu_erf_f16(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_gelu_quick
 
 static void ggml_compute_forward_gelu_quick_f32(
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_gelu(params, dst);
             } break;
+        case GGML_UNARY_OP_GELU_ERF:
+            {
+                ggml_compute_forward_gelu_erf(params, dst);
+            } break;
         case GGML_UNARY_OP_GELU_QUICK:
             {
                 ggml_compute_forward_gelu_quick(params, dst);
index 23cbb3051f2c85ac08adf7d71da6c0ce5be3f6db..c77349ebe410c9836ce75acb9a221cd4781a0d1b 100644 (file)
@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
 static const float GELU_COEF_A     = 0.044715f;
 static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+static const float SQRT_2_INV      = 0.70710678118654752440084436210484f;
 
 inline static float ggml_gelu_f32(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
     }
 }
 
+inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = GGML_FP16_TO_FP32(x[i]);
+        float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+        y[i] = GGML_FP32_TO_FP16(res);
+    }
+}
+
 #ifdef GGML_GELU_FP16
 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
     uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        float xi = x[i];
+        y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
+    }
+}
+
 inline static float ggml_gelu_quick_f32(float x) {
     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
 }
index 85dbbcd5d7f99d91066037905aa9d81ab78c2c2d..f78e7eee553b65dd90fa21b52f48f516192157e8 100644 (file)
@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF,
+    GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
     GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                         sigmoid,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                            gelu,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4,                          gelu_4,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF,                        gelu_erf,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4,                      gelu_erf_4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                      gelu_quick,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                    gelu_quick_4,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                            silu,                            true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_GELU_ERF:
+                {
+                    int64_t n = ggml_nelements(dst);
+
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    if (n % 4 == 0) {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
+                        n /= 4;
+                    } else {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 case GGML_UNARY_OP_GELU_QUICK:
                 {
                     int64_t n = ggml_nelements(dst);
index f18473dcb482a6c5e803cec8cb8bdea2978b098f..59899550ed38cc64003c4a540ff1ba0033cb530c 100644 (file)
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
 constant float GELU_COEF_A     = 0.044715f;
 constant float GELU_QUICK_COEF = -1.702f;
 constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;
+constant float SQRT_2_INV      = 0.70710678118654752440084436210484f;
 
 kernel void kernel_gelu(
     device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
     dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
 }
 
+// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+// ref: https://www.johndcook.com/blog/python_erf/
+constant float p_erf  = 0.3275911f;
+constant float a1_erf = 0.254829592f;
+constant float a2_erf = -0.284496736f;
+constant float a3_erf = 1.421413741f;
+constant float a4_erf = -1.453152027f;
+constant float a5_erf = 1.061405429f;
+
+template<typename T>
+T erf_approx(T x) {
+    T sign_x = sign(x);
+    x = fabs(x);
+    T t = 1.0f / (1.0f + p_erf * x);
+    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    return sign_x * y;
+}
+
+kernel void kernel_gelu_erf(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
+}
+
+kernel void kernel_gelu_erf_4(
+    device const float4 * src0,
+    device       float4 * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    device const float4 & x = src0[tpig];
+
+    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
+}
+
 kernel void kernel_silu(
         device const float * src0,
         device       float * dst,
index d48adb9afb824a9d55cba8ad4f65c15d03bb8b00..57d3e39adf7581d4f413d9f54b9ee4b0e06b6a7a 100644 (file)
@@ -1099,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSWISH",
     "HARDSIGMOID",
     "EXP",
+    "GELU_ERF",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
+static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2501,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
 }
 
+// ggml_gelu_erf
+
+struct ggml_tensor * ggml_gelu_erf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
+}
+
+struct ggml_tensor * ggml_gelu_erf_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
+}
+
 // ggml_gelu_quick
 
 struct ggml_tensor * ggml_gelu_quick(