]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kompute : implement op_getrows_f32 (#6403)
authorwoachk <redacted>
Mon, 3 Jun 2024 05:32:16 +0000 (07:32 +0200)
committerGitHub <redacted>
Mon, 3 Jun 2024 05:32:16 +0000 (08:32 +0300)
op_getrows_f32 is required since https://github.com/ggerganov/llama.cpp/pull/6122
for the Vulkan w/ Kompute backend to be functional.

As such, implement this op to make this backend functional again.

CMakeLists.txt
ggml-kompute.cpp
kompute-shaders/op_getrows_f32.comp [new file with mode: 0644]

index 52b392a13ce5ee21ae8e3427c61e941599284b54..a9b33eaa1f60270a705d00087234d8e5fc6db4d7 100644 (file)
@@ -777,6 +777,7 @@ if (LLAMA_KOMPUTE)
             kompute-shaders/op_mul_mat_q4_0.comp
             kompute-shaders/op_mul_mat_q4_1.comp
             kompute-shaders/op_mul_mat_q6_k.comp
+            kompute-shaders/op_getrows_f32.comp
             kompute-shaders/op_getrows_f16.comp
             kompute-shaders/op_getrows_q4_0.comp
             kompute-shaders/op_getrows_q4_1.comp
@@ -809,6 +810,7 @@ if (LLAMA_KOMPUTE)
             shaderop_mul_mat_q4_0.h
             shaderop_mul_mat_q4_1.h
             shaderop_mul_mat_q6_k.h
+            shaderop_getrows_f32.h
             shaderop_getrows_f16.h
             shaderop_getrows_q4_0.h
             shaderop_getrows_q4_1.h
index 0c51c322f8df1eb10f480adbf29e5dd13158e7fe..eabd70d5eeed86eb2d465bee672eabf015a54c0d 100644 (file)
@@ -22,6 +22,7 @@
 #include "shaderop_mul_mat_q4_1.h"
 #include "shaderop_mul_mat_q6_k.h"
 #include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f32.h"
 #include "shaderop_getrows_f16.h"
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
@@ -1146,6 +1147,14 @@ static void ggml_vk_get_rows(
     seq.record<kp::OpAlgoDispatch>(s_algo);
 }
 
+template <typename... Args>
+static void ggml_vk_get_rows_f32(Args&&... args) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
+        kp::shader_data::op_getrows_f32_comp_spv_len);
+
+    ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
+}
+
 template <typename... Args>
 static void ggml_vk_get_rows_f16(Args&&... args) {
     const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
@@ -1371,6 +1380,7 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
             return op->ne[3] == 1;
         case GGML_OP_GET_ROWS:
             switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
@@ -1661,7 +1671,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     } break;
                 case GGML_OP_GET_ROWS:
                     {
-                        if (src0t == GGML_TYPE_F16) {
+                        if (src0t == GGML_TYPE_F32) {
+                            ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+                        } else if (src0t == GGML_TYPE_F16) {
                             ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
                         } else if (src0t == GGML_TYPE_Q4_0) {
                             ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
diff --git a/kompute-shaders/op_getrows_f32.comp b/kompute-shaders/op_getrows_f32.comp
new file mode 100644 (file)
index 0000000..9d7acda
--- /dev/null
@@ -0,0 +1,31 @@
+#version 450
+
+#include "common.comp"
+
+layout(local_size_x = 1) in;
+
+layout (binding = 0) readonly buffer tensorInA { float inA[]; };
+layout (binding = 1) readonly buffer tensorInB { int inB[]; };
+layout (binding = 2) writeonly buffer tensorOut { float out_[]; };
+
+layout (push_constant) uniform parameter {
+    uint inAOff;
+    uint inBOff;
+    uint outOff;
+    int ne00;
+    int nb01;
+    int nb1;
+} pcs;
+
+void dequantize_row_f32(uint x /*Based from inA unaligned*/, uint y /*Based from out_*/, int k) {
+    for (int j = 0; j < k; j++) {
+        out_[y + j] = inA[x + j];
+    }
+}
+
+void main() {
+    const uint i = gl_WorkGroupID.x;
+    const int r = inB[i + pcs.inBOff];
+
+    dequantize_row_f32(r*pcs.nb01/4 + pcs.inAOff, i*pcs.nb1/4 + pcs.outOff, pcs.ne00);
+}