From: GiantPrince Date: Sun, 8 Mar 2026 11:38:17 +0000 (-0400) Subject: ggml-vulkan: Add ELU op support (llama/20183) X-Git-Tag: upstream/1.8.4~68 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=8d97f59639e75de2ea885bd27df51d1f48cc4dc1;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp ggml-vulkan: Add ELU op support (llama/20183) * ggml-Vulkan: add ELU support * ggml-Vulkan: remove extra spaces and variables * ggml-Vulkan: fix format issue * ggml-Vulkan: fix format issue * fix whitespace issue * Update Vulkan.csv and ops.md --- diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 23d6d39e..0bf7d2e2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -744,6 +744,7 @@ struct vk_device_struct { // [src/dst 0=fp32,1=fp16] vk_pipeline pipeline_exp[2]; + vk_pipeline pipeline_elu[2]; vk_pipeline pipeline_gelu[2]; vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; @@ -4373,6 +4374,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + CREATE_UNARY(elu) CREATE_UNARY(gelu) CREATE_UNARY(gelu_erf) CREATE_UNARY(gelu_quick) @@ -9241,6 +9243,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_EXP: return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_ELU: + return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SILU: return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: @@ -12852,6 +12856,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: @@ -14951,6 +14956,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: @@ -16074,6 +16080,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_EXP: tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_ELU: + tensor_clone = ggml_elu(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_SILU: tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp new file mode 100644 index 00000000..84dcbd8c --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + float x = float(data_a[i]); + + if (x < 0.0f) { + x = exp(x) - 1; + } + + data_d[i] = D_TYPE(x); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 85455988..ed077dfb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -867,6 +867,8 @@ void process_shaders() { string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("elu_f16", "elu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});