option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
+option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF)
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
GGML_BACKEND_API int ggml_cpu_has_sve (void);
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
+ GGML_BACKEND_API int ggml_cpu_has_sme (void);
// other
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
# check_arm_feature(<tag> <code>)
# Compile AND run <code> with "${ARM_MCPU_FLAG}+<tag>" to verify that both the
# toolchain and the build machine support ARM ISA extension <tag>.
# On success, appends "+<tag>" to ARM_MCPU_FLAG_FIX in the parent scope;
# on failure, appends "+no<tag>" — but only when the compiler accepts that
# spelling (probed with a trivial compile check, since older toolchains
# reject unknown "+no*" suffixes).
function(check_arm_feature tag code)
    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
    set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
-    check_cxx_source_runs(
-        "${code}"
-        GGML_MACHINE_SUPPORTS_${tag}
-    )
+    check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
    # Cache variable set by check_cxx_source_runs: true only if the probe both
    # compiled and ran successfully on this machine.
    if (GGML_MACHINE_SUPPORTS_${tag})
        set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
    else()
-        set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+        set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
+        check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
+        if (GGML_MACHINE_SUPPORTS_no${tag})
+            set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+        endif()
    endif()
    # Restore the caller's check flags so later probes are unaffected.
    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
endfunction()
# Probe each optional ARM extension; each call appends +<tag> / +no<tag> to
# ARM_MCPU_FLAG_FIX (consumed below when assembling ARCH_FLAGS).
check_arm_feature(dotprod "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
check_arm_feature(i8mm "#include <arm_neon.h>\nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
check_arm_feature(sve "#include <arm_sve.h>\nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
# The SME probe must execute in streaming mode, hence __arm_locally_streaming
# and the explicit smstart/smstop pair.
+check_arm_feature(sme "#include <arm_sme.h>\n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
else()
if (ARM_FEATURE_RESULT)
message(WARNING "Failed to get ARM features")
else()
- foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
+ foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
if (NOT ${feature_pos} EQUAL -1)
message(STATUS "ARM feature ${feature} enabled")
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
endif()
+    # Optional: route Q4_0 matmul through Arm's KleidiAI micro-kernel library.
+    if (GGML_CPU_KLEIDIAI)
+        message(STATUS "Using KleidiAI optimized kernels if applicable")
+
+        # Disable the KleidiAI tests
+        set(KLEIDIAI_BUILD_TESTS OFF)
+
+        # Fetch KleidiAI sources:
+        include(FetchContent)
+        set(KLEIDIAI_COMMIT_TAG "v1.3.0")
+        set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
+        set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
+
+        if (POLICY CMP0135)
+            cmake_policy(SET CMP0135 NEW)
+        endif()
+
+        # DOWNLOAD_EXTRACT_TIMESTAMP takes a boolean, not a policy value.
+        FetchContent_Declare(KleidiAI_Download
+            URL ${KLEIDIAI_DOWNLOAD_URL}
+            DOWNLOAD_EXTRACT_TIMESTAMP TRUE
+            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
+
+        FetchContent_MakeAvailable(KleidiAI_Download)
+        FetchContent_GetProperties(KleidiAI_Download
+            SOURCE_DIR  KLEIDIAI_SRC
+            POPULATED   KLEIDIAI_POPULATED)
+
+        if (NOT KLEIDIAI_POPULATED)
+            message(FATAL_ERROR "KleidiAI source download failed.")
+        endif()
+
+        add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
+
+        # Keep the fetched kleidiai library target out of the default build;
+        # only the individual micro-kernel sources listed below are compiled.
+        if (TARGET kleidiai)
+            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
+        endif()
+
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/kleidiai/kleidiai.cpp
+            ggml-cpu/kleidiai/kernels.cpp
+            ggml-cpu/kleidiai/kleidiai.h
+            ggml-cpu/kleidiai/kernels.h
+            )
+
+        # KleidiAI include paths
+        include_directories(
+            ${KLEIDIAI_SRC}/
+            ${KLEIDIAI_SRC}/kai/
+            ${KLEIDIAI_SRC}/kai/ukernels/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
+
+        # Decide which optional micro-kernels to build from the -march string
+        # (fall back to CMAKE_C_FLAGS when ARCH_FLAGS was not populated).
+        set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
+        if (NOT ARCH_FLAGS_TEMP)
+            string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
+        endif()
+        string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm"    I8MM_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sme"     SME_ENABLED)
+
+        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
+
+        # Packing routines are always needed, independent of ISA extensions.
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+
+        # string(FIND) yields -1 when the substring is absent; compare
+        # numerically rather than with a regex match.
+        if (NOT DOTPROD_ENABLED EQUAL -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+        endif()
+
+        if (NOT I8MM_ENABLED EQUAL -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+        endif()
+
+        if (NOT SME_ENABLED EQUAL -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
+            # The SME2 micro-kernels also contain SVE/SVE2 instructions.
+            set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+        endif()
+
+        set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
+        list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
+    endif()
+
message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
int has_i8mm;
int has_sve;
int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
+ int has_sme;
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
#endif
#define HWCAP2_I8MM (1 << 13)
#endif
+#if !defined(HWCAP2_SME)
+#define HWCAP2_SME (1 << 23)
+#endif
+
static void ggml_init_arm_arch_features(void) {
#if defined(__linux__) && defined(__aarch64__)
uint32_t hwcap = getauxval(AT_HWCAP);
uint32_t hwcap2 = getauxval(AT_HWCAP2);
- ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+ ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
- ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
- ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
+ ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
+ ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
+ ggml_arm_arch_features.has_sme = !!(hwcap2 & HWCAP2_SME);
#if defined(__ARM_FEATURE_SVE)
ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
}
ggml_arm_arch_features.has_i8mm = oldp;
+ if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
+ oldp = 0;
+ }
+ ggml_arm_arch_features.has_sme = oldp;
+
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#else
ggml_arm_arch_features.has_sve = 0;
ggml_arm_arch_features.sve_cnt = 0;
#endif
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
+ ggml_arm_arch_features.has_sme = 1;
+#else
+ ggml_arm_arch_features.has_sme = 0;
+#endif
#endif
}
#endif
#endif
}
+// Returns non-zero when the running CPU supports SME (Scalable Matrix
+// Extension). Reports 0 unless the binary was also compiled with SME support
+// (__ARM_FEATURE_SME), mirroring the other ggml_cpu_has_* helpers.
+int ggml_cpu_has_sme(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
+    return ggml_arm_arch_features.has_sme;
+#else
+    return 0;
+#endif
+}
+
void ggml_cpu_init(void) {
// needed to initialize f16 tables
{
#include "ggml-cpu-hbm.h"
#endif
+#ifdef GGML_USE_CPU_KLEIDIAI
+#include "kleidiai/kleidiai.h"
+#endif
+
#if defined(__APPLE__)
#include <sys/types.h>
#include <sys/sysctl.h>
}
#endif
+#ifdef GGML_USE_CPU_KLEIDIAI
+ if (ggml_backend_cpu_kleidiai_buffer_type()) {
+ bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
+ }
+#endif
+
#ifdef GGML_USE_CPU_AARCH64
if (ggml_backend_cpu_aarch64_buffer_type()) {
bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
features.push_back({ "SVE_CNT", sve_cnt.c_str() });
}
+ if (ggml_cpu_has_sme()) {
+ features.push_back({ "SME", "1" });
+ }
if (ggml_cpu_has_riscv_v()) {
features.push_back({ "RISCV_V", "1" });
}
#ifdef GGML_USE_OPENMP
features.push_back({ "OPENMP", "1" });
#endif
+ #ifdef GGML_USE_CPU_KLEIDIAI
+ features.push_back({ "KLEIDIAI", "1" });
+ #endif
#ifdef GGML_USE_CPU_AARCH64
features.push_back({ "AARCH64_REPACK", "1" });
#endif
--- /dev/null
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
+// KleidiAI micro-kernels
+#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
+#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
+#include "kai_common.h"
+
+#include "kernels.h"
+
+// Fully parenthesize the macro so it composes safely inside larger
+// expressions (e.g. `a / NELEMS(x)`).
+#define NELEMS(x) (sizeof(x) / sizeof(*(x)))
+
+// Table of GEMM/GEMV micro-kernel bundles, ordered from most to least
+// specialized (SME first). ggml_kleidiai_select_kernels() walks it in order
+// and returns the first entry whose required_cpu features are all present.
+// On Apple platforms the dotprod bundle is preferred over i8mm; elsewhere
+// the order is reversed.
+static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
+#if defined(__ARM_FEATURE_SME)
+    {
+        /* SME GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+        },
+        /* SME GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset             = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset      = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size            = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .pack_func              = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .require_aligned_m_idx  = */ true,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size            = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func              = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+        },
+        /* .required_cpu = */ CPU_FEATURE_SME,
+    },
+#endif
+#if defined(__APPLE__)
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+        },
+        /* DOTPROD GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset             = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset      = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size            = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func              = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx  = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size            = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func              = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+    },
+#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* i8mm GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+        },
+        /* i8mm GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset             = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset      = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size            = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func              = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx  = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size            = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func              = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+    },
+#endif
+#else
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* i8mm GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+        },
+        /* i8mm GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset             = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset      = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size            = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func              = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx  = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size            = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func              = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+    },
+#endif
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+        },
+        /* DOTPROD GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step             = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_n_step             = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_mr                 = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_nr                 = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_kr                 = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_sr                 = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset         = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_rhs_packed_offset  = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_offset         = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_dst_size           = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .run_kernel             = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+        },
+        /* .lhs_info = */ {
+            /* .get_offset             = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset      = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .packed_size            = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
+            /* .pack_func              = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .require_aligned_m_idx  = */ false,
+        },
+        /* .rhs_info = */ {
+            /* .packed_size            = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func              = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+    },
+#endif
+#endif
+};
+
+// Returns the first (highest-priority) kernel bundle in gemm_gemv_kernels
+// whose CPU feature requirements are all satisfied by `features`, or nullptr
+// when none match.
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) {
+    for (auto & candidate : gemm_gemv_kernels) {
+        const cpu_feature required = candidate.required_cpu;
+        if ((features & required) == required) {
+            return &candidate;
+        }
+    }
+    return nullptr;
+}
--- /dev/null
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+// Make the header self-contained: size_t and uint8_t appear in the
+// function-pointer signatures below.
+#include <cstddef>
+#include <cstdint>
+
+// CPU feature bit-flags used to select a kernel bundle at runtime.
+enum cpu_feature {
+    CPU_FEATURE_NONE = 0,
+    CPU_FEATURE_DOTPROD = 1,
+    CPU_FEATURE_I8MM = 2,
+    CPU_FEATURE_SVE = 4,
+    CPU_FEATURE_SME = 8
+};
+inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
+    lhs = static_cast<cpu_feature>(lhs | rhs);
+    return lhs;
+}
+inline cpu_feature operator|(cpu_feature lhs, cpu_feature rhs) {
+    return static_cast<cpu_feature>(static_cast<int>(lhs) | static_cast<int>(rhs));
+}
+
+// Function-pointer table describing one KleidiAI matmul micro-kernel
+// (step sizes, packed-buffer offset helpers, and the kernel entry point).
+struct kernel_info {
+    size_t (*get_m_step)(void);
+    size_t (*get_n_step)(void);
+    size_t (*get_mr)(void);
+    size_t (*get_nr)(void);
+    size_t (*get_kr)(void);
+    size_t (*get_sr)(void);
+    size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl);
+    size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl);
+    size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
+    size_t (*get_dst_size)(size_t m, size_t n);
+    void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
+                       float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max);
+};
+
+// LHS (activation) quantize-and-pack routine descriptors.
+struct lhs_packing_info {
+    size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
+    size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+    size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+    void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
+                      size_t lhs_stride, void* lhs_packed);
+    // When true, callers must hand pack_func an m_idx aligned to the kernel's
+    // requirements (used by the SME bundle).
+    bool require_aligned_m_idx;
+};
+
+// RHS (weight) pack routine descriptors.
+struct rhs_packing_info {
+    size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
+    void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
+                      const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params);
+};
+
+// A matched GEMM/GEMV pair plus its packing routines and the CPU features
+// they require.
+struct ggml_kleidiai_kernels {
+    kernel_info gemm;
+    kernel_info gemv;
+    lhs_packing_info lhs_info;
+    rhs_packing_info rhs_info;
+
+    cpu_feature required_cpu;
+};
+
+// Returns the first kernel bundle whose requirements are satisfied by
+// cpu_features, or nullptr when none match.
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features);
--- /dev/null
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+#include <arm_neon.h>
+#include <assert.h>
+#include <cfloat>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#if defined(__linux__)
+#include <asm/hwcap.h>
+#include <sys/auxv.h>
+#elif defined(__APPLE__)
+#include <string_view>
+#include <sys/sysctl.h>
+#include <sys/types.h>
+#elif defined(_WIN32)
+#include <windows.h>
+#include <excpt.h>
+#endif
+
+#include "kleidiai.h"
+
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml-threading.h"
+#include "ggml-cpu-traits.h"
+
+#include "kernels.h"
+
+#include "kai_common.h"
+
+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
+// Process-wide KleidiAI state: the kernel bundle selected for this CPU.
+// NULL until init_kleidiai_context() runs, or when no bundle matches.
+struct ggml_kleidiai_context {
+    ggml_kleidiai_kernels * kernels;
+} static ctx = { NULL };
+
+// One-time, thread-safe selection of the KleidiAI kernel bundle for this CPU.
+// SME kernels are opt-in via the GGML_KLEIDIAI_SME environment variable
+// (any non-zero integer value enables them).
+static void init_kleidiai_context(void) {
+
+    ggml_critical_section_start();
+    static bool initialized = false;
+
+    if (!initialized) {
+        initialized = true;
+        const char *env_var = getenv("GGML_KLEIDIAI_SME");
+        int sme_enabled = 0;
+
+        // Base feature set as reported by the ggml CPU backend.
+        cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                               (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
+                               (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
+
+        if (env_var) {
+            sme_enabled = atoi(env_var);
+        }
+
+        if (sme_enabled != 0) {
+            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+        }
+        // May leave ctx.kernels NULL when no bundle matches this CPU.
+        ctx.kernels = ggml_kleidiai_select_kernels(features);
+    }
+    ggml_critical_section_end();
+}
+
+// Bounds-checked accessor for tensor->ne[dim] (number of elements along dim).
+static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
+    GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
+    return tensor->ne[dim];
+}
+
+namespace ggml::cpu::kleidiai {
+// Tensor-traits implementation that routes GGML_OP_MUL_MAT over KleidiAI
+// packed Q4_0 weights through the kernel bundle in ctx.kernels.
+class tensor_traits : public ggml::cpu::tensor_traits {
+    // Workspace requirement: room for the quantized + packed LHS
+    // (activations) of the whole matmul.
+    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        GGML_ASSERT(ctx.kernels);
+        // GEMV kernel for a single activation row, GEMM otherwise.
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+
+        size_t k = op->src[0]->ne[0];
+        size_t m = op->src[1]->ne[1];
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr);
+
+        return true;
+    }
+
+    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
+        if (dst->op == GGML_OP_MUL_MAT) {
+            // src0: KleidiAI-packed RHS (weights); src1: f32 LHS (activations).
+            const ggml_tensor * src0 = dst->src[0];
+            const ggml_tensor * src1 = dst->src[1];
+
+            GGML_TENSOR_BINARY_OP_LOCALS
+
+            GGML_ASSERT(ctx.kernels);
+            kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm;
+            lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+
+            GGML_ASSERT(kernel);
+
+            const int ith = params->ith;
+            const int nth = params->nth;
+
+            const size_t k = ne00;
+            const size_t m = ne11;
+            const size_t n = ne01;
+
+            // Partition the N (output column) dimension across threads,
+            // rounded up to the kernel's n_step granularity.
+            const size_t n_step = kernel->get_n_step();
+            const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+            const size_t n_start = ith * num_n_per_thread;
+
+            size_t n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
+
+            const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
+            uint8_t * lhs_packed = (uint8_t*)params->wdata;   // workspace from work_size()
+            const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+            size_t mr = kernel->get_mr();
+            size_t kr = kernel->get_kr();
+            size_t sr = kernel->get_sr();
+
+            // Calculate number of columns to be processed per thread
+            // (single-threaded LHS packing when the kernel needs aligned
+            // m_idx and m fits in one mr block — the SME case).
+            const bool use_multithread = lhs_info->require_aligned_m_idx && m <= mr ? false : true;
+            const size_t num_m_per_thread = use_multithread ? kai_roundup(m, nth) / nth : m;
+            const size_t m_start = ith * num_m_per_thread;
+            size_t m_to_process = num_m_per_thread;
+            if ((m_start + m_to_process) > m) {
+                m_to_process = m - m_start;
+            }
+
+            if(m_start < m) {
+                // Transform LHS: quantize+pack this thread's slice of rows.
+                const size_t src_stride = src1->nb[1];
+                const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
+                const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr);
+                void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+                lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
+            }
+
+            // All threads must finish packing before any thread reads the
+            // shared packed-LHS buffer.
+            ggml_barrier(params->threadpool);
+
+            // Perform the operation: all m rows x this thread's n columns.
+            const size_t dst_stride = dst->nb[1];
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr);
+            const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0);
+            const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+            const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+            const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+            float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+            kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr,
+                               dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+            return true;
+        }
+        return false;
+    }
+
+public:
+    // Re-pack Q4_0 weight data into the KleidiAI RHS layout, in place into
+    // tensor->data. Always uses the GEMM kernel's nr/kr/sr geometry —
+    // NOTE(review): assumes GEMM and GEMV bundles share a compatible RHS
+    // packing; confirm against the kernel table.
+    int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(ctx.kernels);
+        const size_t n = tensor->ne[1];
+        const size_t k = tensor->ne[0];
+        size_t nr = ctx.kernels->gemm.get_nr();
+        size_t kr = ctx.kernels->gemm.get_kr();
+        size_t sr = ctx.kernels->gemm.get_sr();
+
+#ifndef NDEBUG
+        const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0);
+        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
+#endif
+        struct kai_rhs_pack_qs4cxs1s0_param params;
+        params.lhs_zero_point = 1;
+        params.rhs_zero_point = 8;
+        ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, &params);
+
+        return 0;
+
+        GGML_UNUSED(data_size);
+    }
+};
+
+// Single shared traits instance — the class carries no per-tensor state.
+static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
+}
+}  // namespace ggml::cpu::kleidiai
+
+static void ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
+ tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
+ const void * data, size_t offset, size_t size) {
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+
+ auto tensor_traits = (ggml::cpu::kleidiai::tensor_traits *) tensor->extra;
+ auto OK = tensor_traits->repack(tensor, data, size);
+
+ GGML_ASSERT(OK == 0);
+ GGML_UNUSED(buffer);
+}
+
+static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ return "CPU_KLEIDIAI";
+
+ GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+
+ if (buffer == nullptr) {
+ return nullptr;
+ }
+
+ buffer->buft = buft;
+ buffer->iface.init_tensor = ggml_backend_cpu_kleidiai_buffer_init_tensor;
+ buffer->iface.set_tensor = ggml_backend_cpu_kleidiai_buffer_set_tensor;
+ buffer->iface.get_tensor = nullptr;
+ buffer->iface.cpy_tensor = nullptr;
+ return buffer;
+}
+
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return TENSOR_ALIGNMENT;
+
+ GGML_UNUSED(buft);
+}
+
+namespace ggml::cpu::kleidiai {
+class extra_buffer_type : ggml::cpu::extra_buffer_type {
+ bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+ if ( op->op == GGML_OP_MUL_MAT &&
+ op->src[0]->type == GGML_TYPE_Q4_0 &&
+ op->src[0]->buffer &&
+ (ggml_n_dims(op->src[0]) == 2) &&
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels
+ ) {
+ if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+ return false;
+ }
+ if (op->src[1]->type == GGML_TYPE_F32 &&
+ ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
+ if (op->op == GGML_OP_MUL_MAT) {
+ if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+ return (ggml::cpu::tensor_traits *) op->src[0]->extra;
+ }
+ }
+ return nullptr;
+ }
+};
+} // namespace ggml::cpu::kleidiai
+
+ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
+ static ggml::cpu::kleidiai::extra_buffer_type ctx;
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_kleidiai = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_cpu_kleidiai_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
+ /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
+ /* .is_host = */ nullptr,
+ },
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+ /* .context = */ &ctx,
+ };
+
+ init_kleidiai_context();
+
+ return &ggml_backend_cpu_buffer_type_kleidiai;
+}
--- /dev/null
+// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: MIT
+//
+
#pragma once

#include "ggml-alloc.h"

#ifdef __cplusplus
extern "C" {
#endif

// Returns the buffer type that stores Q4_0 weights in the KleidiAI packed
// layout (see the corresponding .cpp for the repacking/set_tensor hooks).
ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void);

#ifdef __cplusplus
}
#endif