]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add quick GELU (#254)
authorM. Yusuf Sarıgöz <redacted>
Fri, 16 Jun 2023 17:36:46 +0000 (20:36 +0300)
committerGitHub <redacted>
Fri, 16 Jun 2023 17:36:46 +0000 (20:36 +0300)
* Implement Quick GELU

* Revert "Implement Quick GELU"

This reverts commit ff220cc1f91a184f195d19b17ed4c352cc72a6f0.

* Tidy up ggml.h

* Respect to the style of ggml

* Fix: Fix minor typo

* Rename `quick_gelu` -> `gelu_quick`

include/ggml/ggml.h
src/ggml.c

index f3df06c95686a08cddeb1ad4516bb1e582957157..1e16900bc6410b68d3e1ad15500bf21c46ce801d 100644 (file)
@@ -290,6 +290,7 @@ extern "C" {
         GGML_OP_STEP,
         GGML_OP_RELU,
         GGML_OP_GELU,
+        GGML_OP_GELU_QUICK,
         GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
@@ -687,6 +688,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_gelu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_silu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
index c485733fcf7ef42e2f85b2839b5c485af5014b63..a3f116cc4518278b9eba3ea6a7825311cbf5b515 100644 (file)
@@ -98,6 +98,7 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
@@ -322,6 +323,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+
 // precomputed silu table for f16 (128 KB)
 static ggml_fp16_t table_silu_f16[1 << 16];
 
@@ -3288,6 +3292,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A    = 0.044715f;
+static const float GELU_QUICK_COEF    = -1.702f;
 static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
@@ -3318,6 +3323,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_gelu_quick_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
@@ -3519,6 +3552,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "STEP",
     "RELU",
     "GELU",
+    "GELU_QUICK",
     "SILU",
     "SILU_BACK",
     "NORM",
@@ -3558,7 +3592,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
+static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3583,6 +3617,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "gelu_quick(x)",
     "silu(x)",
     "silu_back(x)",
     "norm(x)",
@@ -3622,7 +3657,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
+static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3899,7 +3934,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize time system (required on Windows)
         ggml_time_init();
 
-        // initialize GELU, SILU and EXP F32 tables
+        // initialize GELU, Quick GELU, SILU and EXP F32 tables
         {
             const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -3909,13 +3944,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 memcpy(&ii, &ui, sizeof(ii));
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+                table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
                 table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
                 table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-            GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+            GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
         }
 
         // initialize g_state
@@ -5271,6 +5307,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU_QUICK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_quick_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_gelu_quick_impl(ctx, a, true);
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -9118,6 +9188,67 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
 
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //printf("XXXXXXXX quick gelu\n");
+}
+
 // ggml_compute_forward_silu
 
 static void ggml_compute_forward_silu_f32(
@@ -13360,6 +13491,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13771,6 +13906,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ALIBI:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -14578,6 +14717,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MUL:
                 case GGML_OP_GELU:
+                case GGML_OP_GELU_QUICK:
                 case GGML_OP_SILU:
                 case GGML_OP_SILU_BACK:
                 case GGML_OP_NORM: