#include "shaderop_mul_mat_q4_1.h"
#include "shaderop_mul_mat_q6_k.h"
#include "shaderop_mul_mat_mat_f32.h"
+#include "shaderop_getrows_f32.h"
#include "shaderop_getrows_f16.h"
#include "shaderop_getrows_q4_0.h"
#include "shaderop_getrows_q4_1.h"
seq.record<kp::OpAlgoDispatch>(s_algo);
}
+template <typename... Args>
+static void ggml_vk_get_rows_f32(Args&&... args) {
+ const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f32_comp_spv,
+ kp::shader_data::op_getrows_f32_comp_spv_len);
+
+ ggml_vk_get_rows(spirv, "f32", sizeof(float), 0, std::forward<Args>(args)...);
+}
+
template <typename... Args>
static void ggml_vk_get_rows_f16(Args&&... args) {
const static auto spirv = getSpirvShader(kp::shader_data::op_getrows_f16_comp_spv,
return op->ne[3] == 1;
case GGML_OP_GET_ROWS:
switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
} break;
case GGML_OP_GET_ROWS:
{
- if (src0t == GGML_TYPE_F16) {
+ if (src0t == GGML_TYPE_F32) {
+ ggml_vk_get_rows_f32(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
+ } else if (src0t == GGML_TYPE_F16) {
ggml_vk_get_rows_f16(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));
} else if (src0t == GGML_TYPE_Q4_0) {
ggml_vk_get_rows_q4_0(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, nb01, nb1, ggml_nelements(src1));