]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: support GGML_UNARY_OP_XIELU (llama/18062)
authorJeff Bolz <redacted>
Sun, 21 Dec 2025 09:17:58 +0000 (03:17 -0600)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/generic_head.glsl
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
src/ggml-vulkan/vulkan-shaders/xielu.comp [new file with mode: 0644]

index 71939ed35e108bed8b8bfbda51d5182a783e6f71..1c614536362d29dc9c24ed66e6f4911ea67c5c87 100644 (file)
@@ -689,6 +689,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_gelu_quick[2];
     vk_pipeline pipeline_silu[2];
     vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_xielu[2];
     vk_pipeline pipeline_neg[2];
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
@@ -990,6 +991,8 @@ struct vk_op_push_constants {
     uint32_t KY;
     float param1;
     float param2;
+    float param3;
+    float param4;
 };
 
 struct vk_op_glu_push_constants {
@@ -3973,6 +3976,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(gelu_quick)
     CREATE_UNARY(silu)
     CREATE_UNARY(relu)
+    CREATE_UNARY(xielu)
     CREATE_UNARY(neg)
     CREATE_UNARY(tanh)
     CREATE_UNARY(sigmoid)
@@ -8549,6 +8553,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_RELU:
                 return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_XIELU:
+                return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_NEG:
                 return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_TANH:
@@ -9695,14 +9701,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
 
     ggml_vk_op_f32_opt_step_adamw(
         ctx, subctx, dst,
-        { (uint32_t)n, 0, 0.0f, 0.0f }
+        { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
     );
 }
 
 static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
     const size_t n = ggml_nelements(dst->src[0]);
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9788,6 +9794,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg
         1,
         ggml_get_op_params_f32(dst, 0),
         ggml_get_op_params_f32(dst, 2),
+        0.0f, 0.0f,
     };
 
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
@@ -9809,6 +9816,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml
         1,
         ggml_get_op_params_f32(dst, 0),
         0.0f,
+        0.0f, 0.0f,
     };
 
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
@@ -9924,13 +9932,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
 }
 
 static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -9941,7 +9949,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     const float eps = float_op_params[1];
     const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
 }
 
 static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
@@ -10110,16 +10118,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
+}
+
+static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
+        {
+            (uint32_t)ggml_nelements(src0), 0,
+            op_params[1], op_params[2], op_params[3], op_params[4]
+        }
+    );
 }
 
 static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10244,7 +10262,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
 }
 
 static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
@@ -10541,11 +10559,11 @@ static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, co
 }
 
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
 }
 
 static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -10804,7 +10822,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     const float * op_params = (const float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
 }
 
 #ifdef GGML_VULKAN_RUN_TESTS
@@ -12050,6 +12068,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_TRUNC:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
+        case GGML_UNARY_OP_XIELU:
+            ggml_vk_xielu(ctx, compute_ctx, src0, node);
+            break;
         default:
             return false;
         }
@@ -13843,6 +13864,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_XIELU:
                 case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_SIGMOID:
@@ -14748,7 +14770,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         } else if (tensor->op == GGML_OP_LOG) {
             tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_TRI) {
-            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
+            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
         } else if (tensor->op == GGML_OP_DIAG) {
             tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_CLAMP) {
@@ -14836,6 +14858,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             case GGML_UNARY_OP_RELU:
                 tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
                 break;
+            case GGML_UNARY_OP_XIELU:
+                tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
+                ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
+                ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
+                ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
+                ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
+                break;
             case GGML_UNARY_OP_NEG:
                 tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
                 break;
index 66e46ae6796b80803ab5f4031b121109ef802793..3797901f0435e19181bd2b17f562776c20f24616 100644 (file)
@@ -6,4 +6,6 @@ layout (push_constant) uniform parameter
     uint KY;
     float param1;
     float param2;
+    float param3;
+    float param4;
 } p;
index b0ade078c7baf4cc8c84dd86f0605367f421598c..92ad3bcab1e58d7e0d6cb4c26cfdde9c5e64151f 100644 (file)
@@ -853,6 +853,8 @@ void process_shaders() {
     string_to_spv("hardswish_f32",  "hardswish.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("abs_f16",        "abs.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("abs_f32",        "abs.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("xielu_f16",      "xielu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("xielu_f32",      "xielu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
     string_to_spv("tri_f16",        "tri.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("tri_f32",        "tri.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
diff --git a/src/ggml-vulkan/vulkan-shaders/xielu.comp b/src/ggml-vulkan/vulkan-shaders/xielu.comp
new file mode 100644 (file)
index 0000000..35d463b
--- /dev/null
@@ -0,0 +1,35 @@
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    float x = float(data_a[i]);
+
+    float alpha_n = p.param1;
+    float alpha_p = p.param2;
+    float beta = p.param3;
+    float eps = p.param4;
+
+    if (x > 0.0f) {
+        x = alpha_p * x * x + beta * x;
+    } else {
+        const float min_x_eps = min(x, eps);
+        x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x;
+    }
+
+    data_d[i] = D_TYPE(x);
+}