]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-cpu: Add CPU backend support for KleidiAI library (llama/11390)
authorCharles Xu <redacted>
Thu, 20 Feb 2025 13:06:51 +0000 (14:06 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
* ggml-cpu: Add CPU backend support for KleidiAI library

* Add environmental variable GGML_KLEIDIAI_SME

* Add support for multithread LHS conversion

* Switch kernel selection order to dotprod and i8mm

* updates for review comments

* More updates for review comments

* Reorganize and rename KleidiAI files

* Move ggml-cpu-traits.h to source file

* Update cmake for SME build and add alignment for SME

* Remove appending GGML_USE_CPU_KLEIDIAI to the GGML_CDEF_PUBLIC list

ggml/CMakeLists.txt
ggml/include/ggml-cpu.h
ggml/src/ggml-cpu/CMakeLists.txt
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp
ggml/src/ggml-cpu/kleidiai/kernels.cpp [new file with mode: 0644]
ggml/src/ggml-cpu/kleidiai/kernels.h [new file with mode: 0644]
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp [new file with mode: 0644]
ggml/src/ggml-cpu/kleidiai/kleidiai.h [new file with mode: 0644]

index 75b5ea3b439991c1a1292bdd5f05557a958d5aa5..fc5eac151b90cbd6ef7e04c4c80d0840a3b515af 100644 (file)
@@ -102,6 +102,7 @@ endif()
 
 option(GGML_CPU_HBM          "ggml: use memkind for CPU HBM" OFF)
 option(GGML_CPU_AARCH64      "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
+option(GGML_CPU_KLEIDIAI     "ggml: use KleidiAI optimized kernels if applicable" OFF)
 option(GGML_AVX              "ggml: enable AVX"              ${INS_ENB})
 option(GGML_AVX_VNNI         "ggml: enable AVX-VNNI"         OFF)
 option(GGML_AVX2             "ggml: enable AVX2"             ${INS_ENB})
index d23c6b262e202159ac6c890a5706fc1c01582250..9b8a697546ebde6e376c0e12708ef6667e32708b 100644 (file)
@@ -95,6 +95,7 @@ extern "C" {
     GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
     GGML_BACKEND_API int ggml_cpu_has_sve        (void);
     GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
+    GGML_BACKEND_API int ggml_cpu_has_sme        (void);
     // other
     GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void);
     GGML_BACKEND_API int ggml_cpu_has_vsx        (void);
index 26533e512aef3b554e647cdc799c69c18697bade..826d65cece01b2bcb207b936910c7c7b5f8dbc91 100644 (file)
@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 function(check_arm_feature tag code)
                     set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
                     set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
-                    check_cxx_source_runs(
-                        "${code}"
-                        GGML_MACHINE_SUPPORTS_${tag}
-                    )
+                    check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
                     if (GGML_MACHINE_SUPPORTS_${tag})
                         set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
                     else()
-                        set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                        set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
+                        check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
+                        if (GGML_MACHINE_SUPPORTS_no${tag})
+                            set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                        endif()
                     endif()
                     set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
                 endfunction()
@@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
                 check_arm_feature(i8mm    "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
                 check_arm_feature(sve     "#include <arm_sve.h>\nint main()  { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
+                check_arm_feature(sme     "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
 
                 list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
             else()
@@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (ARM_FEATURE_RESULT)
                 message(WARNING "Failed to get ARM features")
             else()
-                foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
+                foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
                     string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
                     if (NOT ${feature_pos} EQUAL -1)
                         message(STATUS "ARM feature ${feature} enabled")
@@ -312,6 +314,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
     endif()
 
+    if (GGML_CPU_KLEIDIAI)
+        message(STATUS "Using KleidiAI optimized kernels if applicable")
+
+        # Disable the KleidiAI tests
+        set(KLEIDIAI_BUILD_TESTS  OFF)
+
+        # Fetch KleidiAI sources:
+        include(FetchContent)
+        set(KLEIDIAI_COMMIT_TAG "v1.3.0")
+        set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
+        set(KLEIDIAI_ARCHIVE_MD5  "060bd2dc64642b091f461cc8dd7426d9")
+
+        if (POLICY CMP0135)
+            cmake_policy(SET CMP0135 NEW)
+        endif()
+
+        FetchContent_Declare(KleidiAI_Download
+            URL ${KLEIDIAI_DOWNLOAD_URL}
+            DOWNLOAD_EXTRACT_TIMESTAMP NEW
+            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
+
+        FetchContent_MakeAvailable(KleidiAI_Download)
+        FetchContent_GetProperties(KleidiAI_Download
+            SOURCE_DIR  KLEIDIAI_SRC
+            POPULATED   KLEIDIAI_POPULATED)
+
+        if (NOT KLEIDIAI_POPULATED)
+            message(FATAL_ERROR "KleidiAI source download failed.")
+        endif()
+
+        add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
+
+        # Exclude the fetched kleidiai target from the default build (it is only needed as a source drop)
+        if (TARGET kleidiai)
+            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
+        endif()
+
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/kleidiai/kleidiai.cpp
+            ggml-cpu/kleidiai/kernels.cpp
+            ggml-cpu/kleidiai/kleidiai.h
+            ggml-cpu/kleidiai/kernels.h
+            )
+
+        # KleidiAI
+        include_directories(
+            ${KLEIDIAI_SRC}/
+            ${KLEIDIAI_SRC}/kai/
+            ${KLEIDIAI_SRC}/kai/ukernels/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
+
+        set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
+        if (NOT ARCH_FLAGS_TEMP)
+            string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
+        endif()
+        string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+
+        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
+
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+
+        if (NOT DOTPROD_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+        endif()
+
+        if (NOT I8MM_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+        endif()
+
+        if (NOT SME_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
+            set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+        endif()
+
+        set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
+        list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
+    endif()
+
     message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
     target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
     target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
index dbef5df2111c6c4ca8588fee5cce9f184022f3c2..f27b981715aded35b6b40cf73ed1c71aeff6f440 100644 (file)
@@ -112,7 +112,8 @@ struct ggml_arm_arch_features_type {
     int has_i8mm;
     int has_sve;
     int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
+    int has_sme;
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
 #endif
 
 
@@ -2381,15 +2382,20 @@ bool ggml_is_numa(void) {
 #define HWCAP2_I8MM (1 << 13)
 #endif
 
+#if !defined(HWCAP2_SME)
+#define HWCAP2_SME (1 << 23)
+#endif
+
 static void ggml_init_arm_arch_features(void) {
 #if defined(__linux__) && defined(__aarch64__)
     uint32_t hwcap = getauxval(AT_HWCAP);
     uint32_t hwcap2 = getauxval(AT_HWCAP2);
 
-    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_neon    = !!(hwcap & HWCAP_ASIMD);
     ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
-    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
-    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
+    ggml_arm_arch_features.has_i8mm    = !!(hwcap2 & HWCAP2_I8MM);
+    ggml_arm_arch_features.has_sve     = !!(hwcap & HWCAP_SVE);
+    ggml_arm_arch_features.has_sme     = !!(hwcap2 & HWCAP2_SME);
 
 #if defined(__ARM_FEATURE_SVE)
     ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2412,6 +2418,11 @@ static void ggml_init_arm_arch_features(void) {
     }
     ggml_arm_arch_features.has_i8mm = oldp;
 
+    if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_sme = oldp;
+
     ggml_arm_arch_features.has_sve = 0;
     ggml_arm_arch_features.sve_cnt = 0;
 #else
@@ -2435,6 +2446,12 @@ static void ggml_init_arm_arch_features(void) {
     ggml_arm_arch_features.has_sve = 0;
     ggml_arm_arch_features.sve_cnt = 0;
 #endif
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
+    ggml_arm_arch_features.has_sme = 1;
+#else
+    ggml_arm_arch_features.has_sme = 0;
+#endif
 #endif
 }
 #endif
@@ -14442,6 +14459,14 @@ int ggml_cpu_get_sve_cnt(void) {
 #endif
 }
 
+int ggml_cpu_has_sme(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
+    return ggml_arm_arch_features.has_sme;
+#else
+    return 0;
+#endif
+}
+
 void ggml_cpu_init(void) {
     // needed to initialize f16 tables
     {
index be4eadcd021f6518766fe774627dbe6672440e47..d0ae10ee3762f63e5e3d65ec44f2463357512e03 100644 (file)
 #include "ggml-cpu-hbm.h"
 #endif
 
+#ifdef GGML_USE_CPU_KLEIDIAI
+#include "kleidiai/kleidiai.h"
+#endif
+
 #if defined(__APPLE__)
 #include <sys/types.h>
 #include <sys/sysctl.h>
@@ -39,6 +43,12 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
         }
 #endif
 
+#ifdef GGML_USE_CPU_KLEIDIAI
+        if (ggml_backend_cpu_kleidiai_buffer_type()) {
+            bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
+        }
+#endif
+
 #ifdef GGML_USE_CPU_AARCH64
         if (ggml_backend_cpu_aarch64_buffer_type()) {
             bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
@@ -538,6 +548,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
             static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
             features.push_back({ "SVE_CNT", sve_cnt.c_str() });
         }
+        if (ggml_cpu_has_sme()) {
+            features.push_back({ "SME", "1" });
+        }
         if (ggml_cpu_has_riscv_v()) {
             features.push_back({ "RISCV_V", "1" });
         }
@@ -559,6 +572,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
     #ifdef GGML_USE_OPENMP
         features.push_back({ "OPENMP", "1" });
     #endif
+    #ifdef GGML_USE_CPU_KLEIDIAI
+        features.push_back({ "KLEIDIAI", "1" });
+    #endif
     #ifdef GGML_USE_CPU_AARCH64
         features.push_back({ "AARCH64_REPACK", "1" });
     #endif
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
new file mode 100644 (file)
index 0000000..a8a59a8
--- /dev/null
@@ -0,0 +1,259 @@
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
+// KleidiAI micro-kernels
+#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
+#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
+#include "kai_common.h"
+
+#include "kernels.h"
+
+#define NELEMS(x) sizeof(x) / sizeof(*x)
+static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
+#if defined(__ARM_FEATURE_SME)
+    {
+        /* SME GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+        },
+        /* SME GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .require_aligned_m_idx = */ true,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_SME,
+    },
+#endif
+#if defined(__APPLE__)
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+        },
+        /* DOTPROD GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
+    },
+#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* i8mm GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+        },
+        /* i8mm GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+    },
+#endif
+#else
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* i8mm GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+        },
+        /* i8mm GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+    },
+#endif
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+        },
+        /* DOTPROD GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
+    },
+#endif
+#endif
+};
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
+    ggml_kleidiai_kernels * kernels = nullptr;
+
+    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
+        if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
+            kernels = &gemm_gemv_kernels[i];
+            break;
+        }
+    }
+
+    return kernels;
+}
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h
new file mode 100644 (file)
index 0000000..a0b0d14
--- /dev/null
@@ -0,0 +1,61 @@
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+// Bit-flag set of Arm CPU features a kernel table may require; values are
+// powers of two so they can be combined and tested as a mask.
+enum cpu_feature {
+    CPU_FEATURE_NONE    = 0,
+    CPU_FEATURE_DOTPROD = 1,
+    CPU_FEATURE_I8MM    = 2,
+    CPU_FEATURE_SVE     = 4,
+    CPU_FEATURE_SME     = 8
+};
+// Compound bitwise-OR for the unscoped cpu_feature enum (C++ does not
+// provide a built-in |= for enum types).
+inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
+    lhs = static_cast<cpu_feature>(lhs | rhs);
+    return lhs;
+}
+// Bitwise-OR for cpu_feature; the explicit int casts keep the computation
+// on the underlying integer type before converting back to the enum.
+inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
+    return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
+}
+
+// Function-pointer table describing one KleidiAI matmul micro-kernel: tile
+// geometry queries (m/n step and the mr/nr/kr/sr packing factors), byte
+// offset/size helpers for the packed operands and the destination, and the
+// clamped-f32 matmul entry point itself.
+struct kernel_info {
+    size_t (*get_m_step)(void);
+    size_t (*get_n_step)(void);
+    size_t (*get_mr)(void);
+    size_t (*get_nr)(void);
+    size_t (*get_kr)(void);
+    size_t (*get_sr)(void);
+    size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
+    size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
+    size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
+    size_t (*get_dst_size)(size_t m, size_t n);
+    // Runs the matmul on already-packed LHS/RHS, clamping the f32 output to
+    // [scalar_min, scalar_max].
+    void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
+                         float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
+};
+
+// Helpers for quantizing/packing the f32 LHS (activations) into the layout
+// the kernel expects. require_aligned_m_idx signals that pack_func only
+// accepts aligned m_idx_start values, which restricts how packing work may
+// be split across threads (consulted in kleidiai.cpp's compute_forward).
+struct lhs_packing_info {
+    size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
+    size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+    size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+    void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
+                      size_t lhs_stride, void* lhs_packed);
+    bool require_aligned_m_idx;
+};
+
+// Helpers for packing the quantized RHS (weights): query the packed buffer
+// size and perform the packing itself.
+struct rhs_packing_info {
+    size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
+    void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
+                      const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
+};
+
+// A complete kernel set: a GEMM kernel for multi-row LHS, a GEMV kernel for
+// single-row LHS, the matching LHS/RHS packing routines, and the CPU
+// feature mask required to run them.
+struct ggml_kleidiai_kernels {
+    kernel_info gemm;
+    kernel_info gemv;
+    lhs_packing_info lhs_info;
+    rhs_packing_info rhs_info;
+
+    cpu_feature required_cpu;
+};
+
+// Pick the first kernel set supported by `cpu_features`; nullptr if none.
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
new file mode 100644 (file)
index 0000000..66685fd
--- /dev/null
@@ -0,0 +1,287 @@
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+#include <arm_neon.h>
+#include <assert.h>
+#include <cfloat>
+#include <stdint.h>
+#include <string.h>
+#if defined(__linux__)
+#include <asm/hwcap.h>
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <string_view>
+#include <sys/sysctl.h>
+#include <sys/types.h>
+#elif defined(_WIN32)
+#include <windows.h>
+#include <excpt.h>
+#endif
+
+#include "kleidiai.h"
+
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml-threading.h"
+#include "ggml-cpu-traits.h"
+
+#include "kernels.h"
+
+#include "kai_common.h"
+
+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
+// Process-wide state: the kernel table chosen for this CPU. NULL until
+// init_kleidiai_context() has run, or when no kernel set matches the
+// detected features.
+struct ggml_kleidiai_context {
+    ggml_kleidiai_kernels * kernels;
+} static ctx = { NULL };
+
+// One-shot initialization, serialized with ggml's critical section: detect
+// the available CPU features, optionally add SME when the GGML_KLEIDIAI_SME
+// environment variable parses to a non-zero integer, then select the kernel
+// table into ctx.kernels.
+static void init_kleidiai_context(void) {
+
+    ggml_critical_section_start();
+    static bool initialized = false;
+
+    if (!initialized) {
+        initialized = true;
+        // NOTE(review): getenv/atoi are declared in <stdlib.h>, which is not
+        // included directly here — presumably pulled in transitively; confirm.
+        const char *env_var = getenv("GGML_KLEIDIAI_SME");
+        int sme_enabled = 0;
+
+        cpu_feature features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                                (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
+                                (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
+
+        if (env_var) {
+            sme_enabled = atoi(env_var);
+        }
+
+        if (sme_enabled != 0) {
+            // SME kernels are opt-in only, gated behind the env variable.
+            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+        }
+        ctx.kernels = ggml_kleidiai_select_kernels(features);
+    }
+    ggml_critical_section_end();
+}
+
+// Bounds-checked read of tensor->ne[dim] (element count along `dim`).
+static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+    return tensor->ne[dim];
+}
+
+namespace ggml::cpu::kleidiai {
+// Tensor traits attached to weight tensors stored in the KleidiAI buffer
+// type: they route MUL_MAT through the selected KleidiAI kernels.
+class tensor_traits : public ggml::cpu::tensor_traits {
+    // Report the scratch ("work") buffer size compute_forward needs: room
+    // for the packed LHS of an m x k activation matrix.
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        GGML_ASSERT(ctx.kernels);
+        // GEMV kernel geometry for a single LHS row, GEMM otherwise.
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+
+        size_t k = op->src[0]->ne[0];
+        size_t m = op->src[1]->ne[1];
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
+
+        return true;
+    }
+
+    // Multithreaded MUL_MAT: cooperatively pack the f32 LHS into the work
+    // buffer (params->wdata), then split the output columns across threads
+    // and run the kernel against the pre-packed Q4_0 RHS (packed once in
+    // repack() at set_tensor time).
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
+        if (dst->op == GGML_OP_MUL_MAT) {
+            const ggml_tensor * src0 = dst->src[0];
+            const ggml_tensor * src1 = dst->src[1];
+
+            GGML_TENSOR_BINARY_OP_LOCALS
+
+            GGML_ASSERT(ctx.kernels);
+            kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+            lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+
+            GGML_ASSERT(kernel);
+
+            const int ith = params->ith;
+            const int nth = params->nth;
+
+            const size_t k = ne00;
+            const size_t m = ne11;
+            const size_t n = ne01;
+
+            // Split the n output columns into per-thread chunks, rounded up
+            // to the kernel's n_step so each chunk stays kernel-aligned.
+            const size_t n_step = kernel->get_n_step();
+            const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+            const size_t n_start = ith * num_n_per_thread;
+
+            size_t n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
+
+            const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
+            uint8_t * lhs_packed       = (uint8_t*)params->wdata;
+            const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+            size_t mr = kernel->get_mr();
+            size_t kr = kernel->get_kr();
+            size_t sr = kernel->get_sr();
+
+            // Calculate number of columns to be processed per thread
+            // Packing falls back to a single thread when the pack routine
+            // needs an aligned m index and the whole LHS fits in one mr tile.
+            const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
+            const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
+            const size_t m_start = ith * num_m_per_thread;
+            size_t m_to_process = num_m_per_thread;
+            if ((m_start + m_to_process) > m) {
+                m_to_process = m - m_start;
+            }
+
+            if(m_start < m) {
+                // Transform LHS
+                const size_t src_stride        = src1->nb[1];
+                const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
+                const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
+                void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+                lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
+            }
+
+            // All threads must observe the fully packed LHS before matmul.
+            ggml_barrier(params->threadpool);
+
+            // Perform the operation
+            const size_t dst_stride        = dst->nb[1];
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
+            const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
+            const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
+            const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+            const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+            float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+            // Each thread computes all m rows for its own n-column slice.
+            kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
+                               dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+            return true;
+        }
+        return false;
+    }
+
+public:
+    // Repack a Q4_0 weight tensor (k x n logical, tensor->ne[0] = k,
+    // tensor->ne[1] = n) from `data` into the KleidiAI packed RHS layout,
+    // writing the result into tensor->data. Returns 0 on success.
+    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(ctx.kernels);
+        const size_t n = tensor->ne[1];
+        const size_t k = tensor->ne[0];
+        size_t nr      = ctx.kernels->gemm.get_nr();
+        size_t kr      = ctx.kernels->gemm.get_kr();
+        size_t sr      = ctx.kernels->gemm.get_sr();
+
+#ifndef NDEBUG
+        // The repacked layout must fit in the tensor's existing allocation.
+        const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
+        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
+#endif
+        struct kai_rhs_pack_qs4cxs1s0_param params;
+        params.lhs_zero_point = 1;
+        params.rhs_zero_point = 8;
+        ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
+
+        return 0;
+
+        GGML_UNUSED(data_size);
+    }
+};
+
+// Shared, stateless traits instance attached to every tensor in this buffer.
+static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
+}
+}  // namespace ggml::cpu::kleidiai
+
+// Buffer hook: attach the KleidiAI tensor traits to every tensor allocated
+// in this buffer so the CPU backend dispatches matmuls through them.
+static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+    tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
+
+    GGML_UNUSED(buffer);
+}
+
+// Buffer hook: instead of a plain copy, repack the incoming data into the
+// KleidiAI RHS layout. Only whole-tensor writes are supported (offset 0,
+// full size), since partial writes cannot be repacked consistently.
+static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+                                                       const void * data, size_t offset, size_t size) {
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
+    auto OK            = tensor_traits->repack(tensor, data, size);
+
+    GGML_ASSERT(OK == 0);
+    GGML_UNUSED(buffer);
+}
+
+// Human-readable name of this buffer type.
+static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_KLEIDIAI";
+
+    GGML_UNUSED(buft);
+}
+
+// Allocate a regular CPU buffer, then override the hooks that differ for
+// KleidiAI tensors. get_tensor/cpy_tensor are disabled because the stored
+// bytes are repacked and no longer match the logical tensor layout.
+static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+    if (buffer == nullptr) {
+        return nullptr;
+    }
+
+    buffer->buft              = buft;
+    buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
+    buffer->iface.set_tensor  = ggml_backend_cpu_kleidiai_buffer_set_tensor;
+    buffer->iface.get_tensor  = nullptr;
+    buffer->iface.cpy_tensor  = nullptr;
+    return buffer;
+}
+
+// Required allocation alignment for buffers of this type.
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::kleidiai {
+// Extra buffer type implementation that advertises which ops the KleidiAI
+// path can handle and hands back the traits stored on the weight tensor.
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+    // Accept only 2-D Q4_0 x F32 MUL_MAT where the weights live in the
+    // KleidiAI buffer type, a kernel set was selected at init, the
+    // activations are host-resident, and there are no batch dimensions.
+    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        if (    op->op == GGML_OP_MUL_MAT &&
+                op->src[0]->type == GGML_TYPE_Q4_0 &&
+                op->src[0]->buffer &&
+                (ggml_n_dims(op->src[0]) == 2) &&
+                op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
+                ) {
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_F32 &&
+                ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Route supported MUL_MATs to the traits attached to the weight tensor.
+    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+        if (op->op == GGML_OP_MUL_MAT) {
+            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+            }
+        }
+        return nullptr;
+    }
+};
+}  // namespace ggml::cpu::kleidiai
+
+// Singleton accessor for the KleidiAI CPU buffer type. Also runs the
+// context initialization on every call; init_kleidiai_context() itself is
+// guarded to execute only once.
+ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
+    static ggml::cpu::kleidiai::extra_buffer_type ctx;
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
+        /* .iface    = */ {
+                           /* .get_name         = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
+                           /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
+                           /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
+                           /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
+                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes
+                           /* .is_host          = */ nullptr,
+                           },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ &ctx,
+    };
+
+    init_kleidiai_context();
+
+    return &ggml_backend_cpu_buffer_type_kleidiai;
+}
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.h b/ggml/src/ggml-cpu/kleidiai/kleidiai.h
new file mode 100644 (file)
index 0000000..38eac58
--- /dev/null
@@ -0,0 +1,17 @@
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "ggml-alloc.h"
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// Returns the singleton KleidiAI CPU buffer type (defined in kleidiai.cpp).
+ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);
+
+#ifdef  __cplusplus
+}
+#endif