#include "shaderop_mul_mat_q8_0.h"
#include "shaderop_mul_mat_q4_0.h"
#include "shaderop_mul_mat_q4_1.h"
+#include "shaderop_mul_mat_q4_k.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
#include "shaderop_getrows_f32.h"
ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
}
+static void ggml_vk_mul_mat_q4_k(
+ kp::Sequence& seq,
+ const std::shared_ptr<kp::Tensor>& inA,
+ const std::shared_ptr<kp::Tensor>& inB,
+ const std::shared_ptr<kp::Tensor>& out,
+ uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+ int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
+ int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
+ int32_t ne1, int32_t r2, int32_t r3
+) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
+ kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
+
+ struct PushConstants {
+ uint32_t inAOff, inBOff, outOff;
+ int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
+ } pushConsts {
+ 0, 0, 0,
+ ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
+ };
+
+ std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+ if (!komputeManager()->hasAlgorithm(__func__)) {
+ s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
+ } else {
+ s_algo = komputeManager()->getAlgorithm(__func__);
+ s_algo->setTensors({inA, inB, out});
+ s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
+ s_algo->setPushConstants<PushConstants>({pushConsts});
+ s_algo->updateDescriptors(s_kompute_context->pool.get());
+ }
+ seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
static void ggml_vk_mul_mat_q6_k(
kp::Sequence& seq,
const std::shared_ptr<kp::Tensor>& inA,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q4_K:
return true;
default:
;
ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
);
break;
+ case GGML_TYPE_Q4_K:
+ ggml_vk_mul_mat_q4_k(
+ seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+ ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
+ );
+ break;
case GGML_TYPE_Q6_K:
ggml_vk_mul_mat_q6_k(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,