ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel (#13053)
author    Dan Johansson <redacted>
Mon, 12 May 2025 11:06:19 +0000 (13:06 +0200)
committer GitHub <redacted>
Mon, 12 May 2025 11:06:19 +0000 (13:06 +0200)
* ggml-cpu: Integrate fp32=bf16xbf16 SME KleidiAI kernel

Signed-off-by: Dan Johansson <redacted>
* code review fixes

Signed-off-by: Dan Johansson <redacted>
* adds a comment that clarifies barrier usage

Signed-off-by: Dan Johansson <redacted>
---------

Signed-off-by: Dan Johansson <redacted>
Co-authored-by: Charles Xu <redacted>
ggml/src/ggml-cpu/CMakeLists.txt
ggml/src/ggml-cpu/kleidiai/kernels.cpp
ggml/src/ggml-cpu/kleidiai/kernels.h
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
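
For orientation (editor's sketch, not part of the patch): the new bf16 kernel table entry is reached through the per-node selector added in kernels.cpp. The sketch below assumes the in-tree headers and a GGML_OP_MUL_MAT node whose src[0] is an F16 view (the KV cache) and src[1] is F32; pick_kernel is a hypothetical helper, the other names are declared in the files changed here.

#include "kernels.h"  // in-tree header: cpu_feature, kernel_info, ggml_kleidiai_kernels

// Sketch: select the kernel table entry matching the node's operand types and
// the runtime CPU features, then pick GEMV or GEMM from it.
static kernel_info * pick_kernel(cpu_feature features, const ggml_tensor * node) {
    ggml_kleidiai_kernels * kernels = ggml_kleidiai_select_kernels(features, node);
    if (kernels == nullptr) {
        return nullptr;  // no entry matches (operand types or CPU features)
    }
    // A single LHS row selects the GEMV micro-kernel, otherwise GEMM.
    return node->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
}

For the F16 (KV cache) path, compute_forward_kv_cache() in kleidiai.cpp then packs the LHS and RHS into bf16 and invokes kernel->run_kernel.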

diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 9a3085befc47645b1d95ec9be3cd01fc90a2800c..bdaec2881ddc32b0eebf920bfecb17098829e04c 100644
@@ -428,6 +428,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
 
         set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -438,17 +439,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
 
-        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
+        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
 
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
-        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
 
         if (NOT DOTPROD_ENABLED MATCHES -1)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
         endif()
 
         if (NOT I8MM_ENABLED MATCHES -1)
@@ -456,9 +459,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         endif()
 
         if (NOT SME_ENABLED MATCHES -1)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
-            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
-            set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c)
+            set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
         endif()
 
         set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index aacc2bb5ee0f243b6ead19777a5bbbb49d7156e5..910fd0ee4e7430f455451c1bf86eaa8e07b9e03b 100644
@@ -4,16 +4,22 @@
 
 // KleidiAI micro-kernels
 #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
-#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
-#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
-#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
-#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
+#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+
+#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+
+#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
+#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
+#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+
 #include "kai_common.h"
 
 #include "kernels.h"
@@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q4_0,
+        /* .op_type            = */ GGML_TYPE_F32,
+    },
+    {
+        /* SME GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+        },
+        /* SME GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_SME,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_F16,
+        /* .op_type            = */ GGML_TYPE_F32,
     },
 #endif
 #if defined(__APPLE__)
@@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q4_0,
+        /* .op_type            = */ GGML_TYPE_F32,
     },
 #endif
 #if defined(__ARM_FEATURE_MATMUL_INT8)
@@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q4_0,
+        /* .op_type            = */ GGML_TYPE_F32,
     },
 #endif
 #else
@@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q4_0,
+        /* .op_type            = */ GGML_TYPE_F32,
     },
 #endif
 #if defined(__ARM_FEATURE_DOTPROD)
@@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q4_0,
+        /* .op_type            = */ GGML_TYPE_F32,
     },
 #endif
 #endif
 };
 
-ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
+    ggml_kleidiai_kernels * kernel = nullptr;
+
+    if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
+        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
+            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
+                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
+                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
+                gemm_gemv_kernels[i].op_type  == tensor->type) {
+                kernel = &gemm_gemv_kernels[i];
+                break;
+            }
+        }
+    }
+
+    return kernel;
+}
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;
 
     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
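
A self-contained sketch of the eligibility test used in the selectors above: an entry is usable only when every bit of its required_cpu mask is present in the runtime feature mask. The constants here are illustrative stand-ins, not the real cpu_feature values.

#include <cstdio>

int main() {
    const int CPU_FEATURE_DOTPROD = 1 << 0;  // stand-in values for the sketch
    const int CPU_FEATURE_I8MM    = 1 << 1;
    const int CPU_FEATURE_SME     = 1 << 3;

    const int features = CPU_FEATURE_DOTPROD | CPU_FEATURE_SME;   // what the CPU offers
    const int required = CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM;  // what the entry needs

    const bool eligible = (features & required) == required;      // false: I8MM is missing
    std::printf("eligible: %s\n", eligible ? "yes" : "no");
    return 0;
}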
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
index 2ffe97eb42feb78b171ed2a93c2f094990b3d6b9..5ac02bda7c0187ca3a5e9f4d90e82af3301cc605 100644
@@ -4,6 +4,9 @@
 
 #pragma once
 
+#include <functional>
+#include "ggml.h"
+
 enum cpu_feature {
     CPU_FEATURE_NONE    = 0,
     CPU_FEATURE_DOTPROD = 1,
@@ -26,26 +29,53 @@ struct kernel_info {
     size_t (*get_nr)(void);
     size_t (*get_kr)(void);
     size_t (*get_sr)(void);
-    size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
-    size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
+    std::variant<
+        std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
+        std::function<size_t(size_t m_idx, size_t k)>
+    > get_lhs_offset;
+    std::variant<
+        std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
+        std::function<size_t(size_t n_idx, size_t k)>
+    > get_rhs_packed_offset;
     size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
     size_t (*get_dst_size)(size_t m, size_t n);
-    void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
-                         float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
+    std::variant<
+        std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
+            float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
+        std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
+            size_t dst_stride_col, float clamp_min, float clamp_max)>
+    > run_kernel;
 };
 
 struct lhs_packing_info {
     size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
-    size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
-    size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
-    void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
-                      size_t lhs_stride, void* lhs_packed);
+    std::variant<
+        std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
+        std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
+    > get_packed_offset;
+    std::variant<
+        std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
+        std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
+    > packed_size;
+    std::variant<
+        std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
+            size_t lhs_stride, void* lhs_packed)>,
+        std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
+        void* lhs_packed)>
+    > pack_func;
 };
 
 struct rhs_packing_info {
-    size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
-    void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
-                      const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
+    std::variant<
+        std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
+        std::function<size_t(size_t n, size_t k)>
+    > packed_size;
+    std::variant<
+        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
+            const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
+        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
+            const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
+    > pack_func;
 };
 
 struct ggml_kleidiai_kernels {
@@ -55,6 +85,10 @@ struct ggml_kleidiai_kernels {
     rhs_packing_info rhs_info;
 
     cpu_feature required_cpu;
+    ggml_type lhs_type;
+    ggml_type rhs_type;
+    ggml_type op_type;
 };
 
-ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
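
The header above replaces plain function pointers with std::variant so the same field can hold either the blocked Q4_0 signature (with a block length bl) or the new bf16 signature. Below is a self-contained sketch of that pattern, mirroring the variant_call helper added in kleidiai.cpp; the two packed-size functions are toy stand-ins, not KleidiAI calls.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>

// One field, two possible signatures: blocked (Q4_0) or non-blocked (bf16).
using packed_size_fn = std::variant<
    std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
    std::function<size_t(size_t n, size_t k)>
>;

// Dispatch on the argument list; throws if it does not match the stored signature.
template <typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args &&... args) {
    return std::visit([&](auto && fn) -> Ret {
        if constexpr (std::is_invocable_r_v<Ret, decltype(fn), Args...>) {
            return fn(std::forward<Args>(args)...);
        } else {
            throw std::runtime_error("variant_call: argument list does not match stored signature");
        }
    }, var);
}

static size_t q4_0_packed_size(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
    (void) nr; (void) kr;
    return n * k / 2 + (k / bl) * n;      // toy formula: packed nibbles + per-block scales
}

static size_t bf16_packed_size(size_t n, size_t k) {
    return n * k * sizeof(uint16_t);      // toy formula: one bf16 value per element
}

int main() {
    packed_size_fn q4 = q4_0_packed_size; // holds the 5-argument alternative
    packed_size_fn bf = bf16_packed_size; // holds the 2-argument alternative

    std::printf("%zu %zu\n",
                variant_call<size_t>(q4, size_t(32), size_t(64), size_t(4), size_t(8), size_t(32)),
                variant_call<size_t>(bf, size_t(32), size_t(64)));
    return 0;
}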
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 4e89ca0faa2a761f0b062f6cca963af22a1fb874..f3dffdd6bf5d3ab6f47472a037c599974b6cf75f 100644
@@ -34,8 +34,9 @@
 #include "ggml-common.h"
 
 struct ggml_kleidiai_context {
+    cpu_feature features;
     ggml_kleidiai_kernels * kernels;
-} static ctx = { NULL };
+} static ctx = { CPU_FEATURE_NONE, NULL };
 
 static void init_kleidiai_context(void) {
 
@@ -47,18 +48,18 @@ static void init_kleidiai_context(void) {
         const char *env_var = getenv("GGML_KLEIDIAI_SME");
         int sme_enabled = 0;
 
-        cpu_feature features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
-                                (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
-                                (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
+        ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                        (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
+                        (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
 
         if (env_var) {
             sme_enabled = atoi(env_var);
         }
 
         if (sme_enabled != 0) {
-            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels = ggml_kleidiai_select_kernels(features);
+        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
     }
     ggml_critical_section_end();
 }
@@ -68,95 +69,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }
 
+template<typename Ret, typename Variant, typename... Args>
+static Ret variant_call(const Variant & var, Args&&... args) {
+    return std::visit([&](auto&& func) -> Ret {
+        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
+            return func(std::forward<Args>(args)...);
+        } else {
+            throw std::runtime_error("Invalid function type in variant_call");
+        }
+    }, var);
+}
+
 namespace ggml::cpu::kleidiai {
+
+static size_t round_down(size_t x, size_t y) {
+    return y == 0 ? x : x - (x % y);
+}
+
+static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
+    size_t src_stride = rhs_stride / sizeof(uint16_t);
+    size_t dst_stride = n;
+
+    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
+            uint16_t v = *(src + k_idx + n_idx * src_stride);
+            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
+        }
+    }
+}
+
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
-        GGML_ASSERT(ctx.kernels);
-        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
+        GGML_ASSERT(kernels);
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
 
         size_t k = op->src[0]->ne[0];
+        size_t n = op->src[0]->ne[1];
         size_t m = op->src[1]->ne[1];
 
         size_t mr = kernel->get_mr();
         size_t kr = kernel->get_kr();
         size_t sr = kernel->get_sr();
 
-        size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
+        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+        } else if (kernels->rhs_type == GGML_TYPE_F16) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
+                   k * n * sizeof(float) + n * sizeof(float);
+        } else {
+            GGML_ASSERT(false);
+        }
 
         return true;
     }
 
+
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-            const ggml_tensor * src0 = dst->src[0];
-            const ggml_tensor * src1 = dst->src[1];
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_F16) {
+                return compute_forward_kv_cache(params, dst);
+            }
+        }
+        return false;
+    }
 
-            GGML_TENSOR_BINARY_OP_LOCALS
+    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
 
-            GGML_ASSERT(ctx.kernels);
-            kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
-            lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
 
-            GGML_ASSERT(kernel);
+        GGML_TENSOR_BINARY_OP_LOCALS
 
-            const int ith = params->ith;
-            const int nth = params->nth;
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);
 
-            const size_t k = ne00;
-            const size_t m = ne11;
-            const size_t n = ne01;
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        GGML_ASSERT(kernel);
 
-            const size_t n_step = kernel->get_n_step();
-            const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-            const size_t n_start = ith * num_n_per_thread;
+        const int nth = params->nth;
+        const int ith = params->ith;
 
-            size_t n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
-            }
+        const int64_t lhs_batch_size0 = ne12;
+        const int64_t rhs_batch_size0 = ne02;
+        const int64_t batch_size      = rhs_batch_size0;
+
+        const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+
+        const int64_t m = ne11 * r;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
+
+        const size_t lhs_stride = src1->nb[1];
+        const size_t rhs_stride = src0->nb[1];
+        const size_t dst_stride = dst->nb[1];
+
+        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
+        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
+        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
+        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+
+        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
+        const size_t kxn_size        = k * n * sizeof(float);
+        const size_t bias_size       = n * sizeof(float);
+
+        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
+        GGML_ASSERT(wsize_required <= params->wsize);
+
+        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
+        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
+        uint8_t * rhs_kxn    = rhs_packed + rhs_packed_size;
+        uint8_t * bias       = rhs_kxn + kxn_size;
+
+        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
+            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
+            uint8_t * dst_batch       = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
 
-            const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-            uint8_t * lhs_packed       = (uint8_t*)params->wdata;
-            const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+            // LHS packing
+            {
+                const int64_t m_roundup_mr = kai_roundup(m, mr);
+                const int64_t num_threads  = KAI_MIN(m_roundup_mr / mr, nth);
 
-            size_t mr = kernel->get_mr();
-            size_t kr = kernel->get_kr();
-            size_t sr = kernel->get_sr();
+                if (ith < num_threads) {
+                    const int64_t num_m_per_thread0   = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
 
-            // Calculate number of columns to be processed per thread
-            const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-            const size_t m_start = ith * num_m_per_thread;
-            size_t m_to_process = num_m_per_thread;
-            if ((m_start + m_to_process) > m) {
-                m_to_process = m - m_start;
+                    const int64_t m_start          = ith * num_m_per_thread0;
+                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    const size_t lhs_offset        = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+
+                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
+                    void * dst_ptr       = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+
+                    variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                }
             }
 
-            if(m_start < m) {
-                // Transform LHS
-                const size_t src_stride        = src1->nb[1];
-                const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-                const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
-                void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+            // RHS packing
+            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
+                // First thread to reach this point handles RHS packing
+                memset(bias, 0, n * sizeof(float));
+                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
 
-                lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
+                             rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
             }
 
             ggml_barrier(params->threadpool);
 
-            // Perform the operation
-            const size_t dst_stride        = dst->nb[1];
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
-            const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
-            const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-            const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-            const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
-            float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
-
-            kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
-                               dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
-            return true;
+            first_to_arrive.clear(std::memory_order_release);
+
+            // Perform the matmul
+            {
+                const int64_t m_to_process = m;
+                const int64_t m_start      = 0;
+
+                const int64_t n_step      = static_cast<int64_t>(kernel->get_n_step());
+                const int64_t num_threads = KAI_MIN(n / n_step, nth);
+
+                if (ith < num_threads) {
+                    const int64_t num_n_per_thread0   = round_down(n / num_threads, n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+
+                    const int64_t n_start      = ith * num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
+                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
+                    const size_t dst_offset        = kernel->get_dst_offset(m_start, n_start, dst_stride);
+
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
+                    float * dst_ptr      = reinterpret_cast<float *>(dst_batch + dst_offset);
+
+                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                }
+            }
+
+            if (batch_idx != batch_size - 1) {
+                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
+                // the work data buffer (params->wdata) is used as temporary storage which means that only
+                // a single batch can be processed at any given time. No barrier is needed for the last
+                // batch since GGML inserts a barrier between the execution of every operator.
+                ggml_barrier(params->threadpool);
+            }
         }
-        return false;
+
+        return true;
+    }
+
+    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);
+
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = &kernels->lhs_info;
+
+        GGML_ASSERT(kernel);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed       = (uint8_t*)params->wdata;
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = num_n_per_thread;
+        if ((n_start + n_to_process) > n) {
+            n_to_process = n - n_start;
+        }
+
+        // Calculate number of columns to be processed per thread
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if (m_start < m) {
+            // Transform LHS
+            const size_t src_stride        = src1->nb[1];
+            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
+            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        // Perform the operation
+        const size_t dst_stride        = dst->nb[1];
+        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
+        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
+        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+        float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                           sizeof(float), -FLT_MAX, FLT_MAX);
+
+        return true;
     }
 
 public:
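
The LHS packing in compute_forward_kv_cache above splits rows between threads in mr-aligned chunks, with the last participating thread absorbing the remainder. A standalone sketch of that partitioning with assumed sizes (m = 70 rows, mr = 8, 4 threads); round_up stands in for kai_roundup.

#include <algorithm>
#include <cstddef>
#include <cstdio>

static size_t round_up(size_t x, size_t y)   { return y == 0 ? x : ((x + y - 1) / y) * y; }
static size_t round_down(size_t x, size_t y) { return y == 0 ? x : x - (x % y); }

int main() {
    const size_t m = 70, mr = 8, nth = 4;                                    // assumed sizes
    const size_t m_roundup_mr = round_up(m, mr);                             // 72
    const size_t num_threads  = std::min(m_roundup_mr / mr, nth);            // 4 threads participate
    const size_t chunk        = round_down(m_roundup_mr / num_threads, mr);  // 16 rows per thread

    for (size_t ith = 0; ith < num_threads; ++ith) {
        const size_t start = ith * chunk;
        const size_t count = (ith == num_threads - 1) ? m - start : chunk;   // last thread: 22 rows
        std::printf("thread %zu packs rows [%zu, %zu)\n", ith, start, start + count);
    }
    return 0;
}

The matmul phase is split the same way over n, in n_step-aligned chunks.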
@@ -169,13 +350,13 @@ public:
         size_t sr      = ctx.kernels->gemm.get_sr();
 
 #ifndef NDEBUG
-        const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
+        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
         GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
 #endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
+        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
 
         return 0;
 
@@ -189,7 +370,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 }
 }  // namespace ggml::cpu::kleidiai
 
-GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
     GGML_UNUSED(buffer);
@@ -238,12 +419,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (    op->op == GGML_OP_MUL_MAT &&
-                op->src[0]->type == GGML_TYPE_Q4_0 &&
-                op->src[0]->buffer &&
-                (ggml_n_dims(op->src[0]) == 2) &&
-                op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
-                ) {
+        if (op->op == GGML_OP_MUL_MAT &&
+            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            op->src[0]->buffer &&
+            (ggml_n_dims(op->src[0]) == 2) &&
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
@@ -260,6 +440,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
+            else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
+                     op->src[0]->op == GGML_OP_VIEW &&
+                     (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op ==  GGML_OP_SOFT_MAX) &&
+                     op->src[1]->ne[1] > 1) {
+                if ((op->src[0]->nb[0] != 2) ||
+                    (op->src[1]->nb[0] != 4) ||
+                    (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                    return nullptr;
+                }
+
+                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
+            }
         }
         return nullptr;
     }
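
The barrier comment added in this commit refers to the pattern below: one worker is elected via an atomic flag to pack the shared RHS buffer, and every worker waits on a barrier before reading it; with more than one batch, a second barrier keeps batches from overlapping in the shared workspace. A standalone C++20 sketch of the single-batch case, with std::barrier standing in for ggml_barrier:

#include <atomic>
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    constexpr int nth = 4;
    std::atomic_flag first_to_arrive;   // clear on construction (C++20)
    std::barrier     sync(nth);
    std::vector<std::thread> workers;

    for (int ith = 0; ith < nth; ++ith) {
        workers.emplace_back([&, ith] {
            // First thread to reach this point does the shared packing work.
            if (!first_to_arrive.test_and_set(std::memory_order_acquire)) {
                std::printf("thread %d packs the shared RHS buffer\n", ith);
            }
            sync.arrive_and_wait();     // everyone waits until the packed data is ready
            std::printf("thread %d runs its matmul tile\n", ith);
        });
    }
    for (auto & w : workers) {
        w.join();
    }
    return 0;
}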