From: woachk Date: Mon, 3 Jun 2024 05:32:16 +0000 (+0200) Subject: kompute : implement op_getrows_f32 (llama/6403) X-Git-Tag: upstream/0.0.1642~616 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=5ed1871585e461aece960323f0ba8044e6d54183;p=pkg%2Fggml%2Fsources%2Fggml kompute : implement op_getrows_f32 (llama/6403) op_getrows_f32 is required since https://github.com/ggerganov/llama.cpp/pull/6122 for the Vulkan w/ Kompute backend to be functional. As such, implement this op to make this backend functional again. --- diff --git a/src/ggml-kompute.cpp b/src/ggml-kompute.cpp index 0c51c322..eabd70d5 100644 --- a/src/ggml-kompute.cpp +++ b/src/ggml-kompute.cpp @@ -22,6 +22,7 @@ #include "shaderop_mul_mat_q4_1.h" #include "shaderop_mul_mat_q6_k.h" #include "shaderop_mul_mat_mat_f32.h" +#include "shaderop_getrows_f32.h" #include "shaderop_getrows_f16.h" #include "shaderop_getrows_q4_0.h" #include "shaderop_getrows_q4_1.h" @@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows( seq.record(s_algo); } +template +static void ggml_vk_get_rows_f32(Args&&... args) { + const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv, + kp::shader_data::op_getrows_f32_comp_spv_len); + + ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward(args)...); +} + template static void ggml_vk_get_rows_f16(Args&&... args) { const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv, @@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { return op->ne[3] == 1; case GGML_OP_GET_ROWS: switch (op->src[0]->type) { + case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_GET_ROWS: { - if (src0t == GGML_TYPE_F16) { + if (src0t == GGML_TYPE_F32) { + ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); + } else if (src0t == GGML_TYPE_F16) { ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1)); } else if (src0t == GGML_TYPE_Q4_0) { ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));