]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
kompute: add mul_mat_q4_k shader (llama/10097)
authorSergio López <redacted>
Thu, 31 Oct 2024 09:09:52 +0000 (10:09 +0100)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
This is a more or less direct translation from the Metal implementation
to GLSL.

Signed-off-by: Sergio Lopez <redacted>
ggml/src/CMakeLists.txt
ggml/src/ggml-kompute.cpp

index 729f61d737281f19dc51873adac4e67260fb1ab7..153cc8dcd9fd9f4ddd5f1857f730d3e48d86e1d5 100644 (file)
@@ -800,6 +800,7 @@ if (GGML_KOMPUTE)
             kompute-shaders/op_mul_mat_q8_0.comp
             kompute-shaders/op_mul_mat_q4_0.comp
             kompute-shaders/op_mul_mat_q4_1.comp
+            kompute-shaders/op_mul_mat_q4_k.comp
             kompute-shaders/op_mul_mat_q6_k.comp
             kompute-shaders/op_getrows_f32.comp
             kompute-shaders/op_getrows_f16.comp
@@ -833,6 +834,7 @@ if (GGML_KOMPUTE)
             shaderop_mul_mat_q8_0.h
             shaderop_mul_mat_q4_0.h
             shaderop_mul_mat_q4_1.h
+            shaderop_mul_mat_q4_k.h
             shaderop_mul_mat_q6_k.h
             shaderop_getrows_f32.h
             shaderop_getrows_f16.h
index fea69fb0477d531a63c705c246c2d230c4e42452..2fea9e4cc8d38afcea9b8d71a3cdba6c9b626154 100644 (file)
@@ -20,6 +20,7 @@
 #include "shaderop_mul_mat_q8_0.h"
 #include "shaderop_mul_mat_q4_0.h"
 #include "shaderop_mul_mat_q4_1.h"
+#include "shaderop_mul_mat_q4_k.h"
 #include "shaderop_mul_mat_q6_k.h"
 #include "shaderop_mul_mat_mat_f32.h"
 #include "shaderop_getrows_f32.h"
@@ -1067,6 +1068,40 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
     ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
 }
 
+static void ggml_vk_mul_mat_q4_k(
+    kp::Sequence& seq,
+    const std::shared_ptr<kp::Tensor>& inA,
+    const std::shared_ptr<kp::Tensor>& inB,
+    const std::shared_ptr<kp::Tensor>& out,
+    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
+    int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
+    int32_t ne1, int32_t r2, int32_t r3
+) {
+    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
+        kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
+
+    struct PushConstants {
+        uint32_t inAOff, inBOff, outOff;
+        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
+    } pushConsts {
+        0, 0, 0,
+        ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
+    };
+
+    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+    if (!komputeManager()->hasAlgorithm(__func__)) {
+        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
+    } else {
+        s_algo = komputeManager()->getAlgorithm(__func__);
+        s_algo->setTensors({inA, inB, out});
+        s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
+        s_algo->setPushConstants<PushConstants>({pushConsts});
+        s_algo->updateDescriptors(s_kompute_context->pool.get());
+    }
+    seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
 static void ggml_vk_mul_mat_q6_k(
     kp::Sequence& seq,
     const std::shared_ptr<kp::Tensor>& inA,
@@ -1384,6 +1419,7 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
                 case GGML_TYPE_Q8_0:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
+                case GGML_TYPE_Q4_K:
                     return true;
                 default:
                     ;
@@ -1635,6 +1671,12 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                                     ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
                                 );
                                 break;
+                            case GGML_TYPE_Q4_K:
+                                ggml_vk_mul_mat_q4_k(
+                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
+                                );
+                                break;
                             case GGML_TYPE_Q6_K:
                                 ggml_vk_mul_mat_q6_k(
                                     seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,