vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
+ vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
vk_pipeline pipeline_sigmoid[2];
vk_pipeline pipeline_hardsigmoid[2];
vk_pipeline pipeline_hardswish[2];
+ vk_pipeline pipeline_abs[2];
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
CREATE_UNARY(gelu_quick)
CREATE_UNARY(silu)
CREATE_UNARY(relu)
+ CREATE_UNARY(neg)
CREATE_UNARY(tanh)
CREATE_UNARY(sigmoid)
CREATE_UNARY(hardsigmoid)
CREATE_UNARY(hardswish)
+ CREATE_UNARY(abs)
#undef CREATE_UNARY
#define CREATE_UNARY_RTE(name) \
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_RELU:
return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
+ case GGML_UNARY_OP_NEG:
+ return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TANH:
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SIGMOID:
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_HARDSWISH:
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
+ case GGML_UNARY_OP_ABS:
+ return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
default:
break;
}
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_ABS:
break;
default:
return false;
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_ABS:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
default:
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_ABS:
buf = tensor->buffer;
break;
default:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_ABS:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
case GGML_UNARY_OP_RELU:
tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
break;
+ case GGML_UNARY_OP_NEG:
+ tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
+ break;
case GGML_UNARY_OP_TANH:
tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_HARDSWISH:
tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
break;
+ case GGML_UNARY_OP_ABS:
+ tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
+ break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("neg_f16", "neg.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("neg_f32", "neg.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
+ string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";