// [src/dst 0=fp32,1=fp16]
vk_pipeline pipeline_exp[2];
+ vk_pipeline pipeline_elu[2];
vk_pipeline pipeline_gelu[2];
vk_pipeline pipeline_gelu_erf[2];
vk_pipeline pipeline_gelu_quick[2];
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ CREATE_UNARY(elu)
CREATE_UNARY(gelu)
CREATE_UNARY(gelu_erf)
CREATE_UNARY(gelu_quick)
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_EXP:
return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
+ case GGML_UNARY_OP_ELU:
+ return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SILU:
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_GELU:
}
switch (ggml_get_unary_op(node)) {
+ case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_GELU:
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_EXP:
tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
break;
+ case GGML_UNARY_OP_ELU:
+ tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);
+ break;
case GGML_UNARY_OP_SILU:
tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
break;
--- /dev/null
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+// Elementwise ELU: a negative input v becomes exp(v) - 1 (no extra scale
+// factor is applied); any other input is written through unchanged.
+// Reads from data_a and writes to data_d at the same linear index.
+void main() {
+    // Flatten the 3D invocation ID into one linear element index:
+    // x advances by 1, y by 512 and z by 262144 elements.
+    const uint idx = gl_GlobalInvocationID.x
+                   + gl_GlobalInvocationID.y * 512
+                   + gl_GlobalInvocationID.z * 262144;
+
+    // Invocations at or beyond p.KX do nothing. p is declared outside this
+    // block (presumably KX is the element count -- confirm in generic_head.glsl).
+    if (idx < p.KX) {
+        // The math is done in float regardless of A_TYPE / D_TYPE.
+        const float v = float(data_a[idx]);
+
+        // NOTE(review): exp(v) - 1 may lose precision for v very close to 0;
+        // GLSL has no expm1 builtin, so this is kept as-is.
+        data_d[idx] = D_TYPE(v < 0.0f ? exp(v) - 1 : v);
+    }
+}
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("elu_f16", "elu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});