]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
add GELU_ERF (llama/14455)
authorSigbjørn Skjæret <redacted>
Tue, 1 Jul 2025 08:14:21 +0000 (10:14 +0200)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/gelu_erf.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 110348bf065f2ad79238c4a5e6a686795f8d2bea..fae7851a616d11b7c9d8f0f4fdf227cd0f0797ff 100644 (file)
@@ -432,6 +432,7 @@ struct vk_device_struct {
 
     // [src/dst 0=fp32,1=fp16]
     vk_pipeline pipeline_gelu[2];
+    vk_pipeline pipeline_gelu_erf[2];
     vk_pipeline pipeline_gelu_quick[2];
     vk_pipeline pipeline_silu[2];
     vk_pipeline pipeline_relu[2];
@@ -2798,6 +2799,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     CREATE_UNARY(gelu)
+    CREATE_UNARY(gelu_erf)
     CREATE_UNARY(gelu_quick)
     CREATE_UNARY(silu)
     CREATE_UNARY(relu)
@@ -6531,6 +6533,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_GELU:
                 return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_GELU_ERF:
+                return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_GELU_QUICK:
                 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_RELU:
@@ -8822,6 +8826,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
@@ -9072,6 +9077,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
@@ -9290,6 +9296,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         switch (ggml_get_unary_op(tensor)) {
         case GGML_UNARY_OP_SILU:
         case GGML_UNARY_OP_GELU:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
         case GGML_UNARY_OP_TANH:
@@ -10096,6 +10103,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
@@ -10836,6 +10844,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         case GGML_UNARY_OP_GELU:
             tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
             break;
+        case GGML_UNARY_OP_GELU_ERF:
+            tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
+            break;
         case GGML_UNARY_OP_GELU_QUICK:
             tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
             break;
diff --git a/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp
new file mode 100644 (file)
index 0000000..5fd5a5e
--- /dev/null
@@ -0,0 +1,39 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
+    // ref: https://www.johndcook.com/blog/python_erf/
+    const float p_erf  = 0.3275911f;
+    const float a1_erf = 0.254829592f;
+    const float a2_erf = -0.284496736f;
+    const float a3_erf = 1.421413741f;
+    const float a4_erf = -1.453152027f;
+    const float a5_erf = 1.061405429f;
+
+    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float a = float(data_a[i]);
+    const float a_div_sqr2 = a * SQRT_2_INV;
+    const float sign_x = sign(a_div_sqr2);
+    const float x = abs(a_div_sqr2);
+    const float t = 1.0f / (1.0f + p_erf * x);
+    const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
+    const float erf_approx = sign_x * y;
+
+    data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx));
+}
index 5b527b104cc9312aeb5aa613a5e0a8989897289c..4658ad0c9c26f58ac5b3553e52fa1812643e76cf 100644 (file)
@@ -574,6 +574,8 @@ void process_shaders() {
 
     string_to_spv("gelu_f16",       "gelu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("gelu_f32",       "gelu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("gelu_erf_f16",   "gelu_erf.comp",    {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_erf_f32",   "gelu_erf.comp",    {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("gelu_quick_f16", "gelu_quick.comp",  {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("gelu_quick_f32", "gelu_quick.comp",  {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("silu_f16",       "silu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});