git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
SYCL: add gelu_erf kernel (llama/13749)
author: Akarshan Biswas <redacted>
Tue, 27 May 2025 15:22:59 +0000 (20:52 +0530)
committer: Georgi Gerganov <redacted>
Sun, 1 Jun 2025 11:01:05 +0000 (14:01 +0300)
* SYCL: add gelu_erf kernel

* refactor code

Co-authored-by: Atharva Dubey <redacted>
* Use scope_op_debug_print

---------

Co-authored-by: Atharva Dubey <redacted>
src/ggml-sycl/element_wise.cpp
src/ggml-sycl/element_wise.hpp
src/ggml-sycl/ggml-sycl.cpp

index fd3cfb573e29ca1f6007c14a6433a0e75e313710..5b7c4f0b4f003c55c8fadf4e37386cebb87074fa 100644 (file)
@@ -84,6 +84,15 @@ static void gelu_quick(const T *x, T *dst, int k,
     dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
+template<typename T>
+static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+       auto x_i = x[i];
+        dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
+    }
+}
+
 template<typename T>
 static void tanh(const T *x, T *dst, int k,
                      const sycl::nd_item<3> &item_ct1) {
@@ -400,6 +409,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
         });
 }
 
+
+// Host-side launcher: enqueues the gelu_erf kernel over k elements on the
+// given SYCL queue. Launch shape: ceil_div(k, SYCL_GELU_BLOCK_SIZE)
+// work-groups of SYCL_GELU_BLOCK_SIZE work-items along dimension 2.
+// NOTE(review): no wait() here — completion ordering presumably relies on
+// in-order queue semantics elsewhere in the backend; confirm with callers.
+template<typename T>
+static void gelu_erf_sycl(const T *x, T *dst, const int k,
+                                queue_ptr stream) {
+    const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_erf(x, dst, k, item_ct1);
+        });
+}
+
 template<typename T>
 static void tanh_sycl(const T *x, T *dst, const int k,
                           queue_ptr stream) {
@@ -816,6 +839,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
+// ggml op handler for GELU_ERF: validates dtypes, selects the device/stream
+// from the context, then dispatches gelu_erf_sycl with float or sycl::half
+// pointers depending on dst->type. With GGML_SYCL_F16 both F32 and F16 are
+// accepted; otherwise only F32. Source and destination must share a type.
+inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+#if defined (GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined (GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                // cast_data presumably yields typed src/dst pointers for the op
+                auto data_pts = cast_data<sycl::half>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+
+
+
 inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined (GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1425,6 +1480,11 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_gelu_quick(ctx, dst);
 }
 
+// Public entry point for the GELU_ERF unary op: emits a scoped debug trace
+// for this op (one source tensor), then forwards to the typed handler.
+void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
+    ggml_sycl_op_gelu_erf(ctx, dst);
+}
+
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
     ggml_sycl_op_tanh(ctx, dst);
index f4199d69da694dfdd7a8e8f57a0ade0415c0fb4b..bd40113f0970560317caa4d4fc4bc506c5485495 100644 (file)
@@ -38,6 +38,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
+void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
index 6a53bd12c4efb6ccfdd76529a444c5abbc596231..e96e1f248845183a7eea0e263f710bd4ba4ce15f 100644 (file)
@@ -3543,6 +3543,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                 case GGML_UNARY_OP_GELU_QUICK:
                     ggml_sycl_gelu_quick(ctx, dst);
                     break;
+                case GGML_UNARY_OP_GELU_ERF:
+                    ggml_sycl_gelu_erf(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_TANH:
                     ggml_sycl_tanh(ctx, dst);
                     break;
@@ -4096,6 +4099,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                 case GGML_UNARY_OP_SGN: