kleidiai: add and integrate SVE 256-bit vector-length kernel (llama/18458)
author     Charles Xu <redacted>
           Tue, 30 Dec 2025 12:04:53 +0000 (13:04 +0100)
committer  Georgi Gerganov <redacted>
           Wed, 31 Dec 2025 10:39:43 +0000 (12:39 +0200)
* kleidiai: add and integrate SVE 256-bit vector-length kernel

* updated for review comments

src/CMakeLists.txt
src/ggml-cpu/CMakeLists.txt
src/ggml-cpu/kleidiai/kernels.cpp
src/ggml-cpu/kleidiai/kleidiai.cpp

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 25f25c423643171016fc8097daba623892bce6c4..6192a87046688598df35c866a5ae30016bdd14a8 100644
@@ -401,8 +401,8 @@ if (GGML_CPU_ALL_VARIANTS)
             ggml_add_cpu_backend_variant(android_armv8.2_2    DOTPROD FP16_VECTOR_ARITHMETIC)
             ggml_add_cpu_backend_variant(android_armv8.6_1    DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
             ggml_add_cpu_backend_variant(android_armv9.0_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
-            ggml_add_cpu_backend_variant(android_armv9.2_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
-            ggml_add_cpu_backend_variant(android_armv9.2_2    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+            ggml_add_cpu_backend_variant(android_armv9.2_1    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+            ggml_add_cpu_backend_variant(android_armv9.2_2    DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
         elseif (APPLE)
             ggml_add_cpu_backend_variant(apple_m1             DOTPROD)
             ggml_add_cpu_backend_variant(apple_m2_m3          DOTPROD MATMUL_INT8)
diff --git a/src/ggml-cpu/CMakeLists.txt b/src/ggml-cpu/CMakeLists.txt
index 28fb7612e5725586d291568713988e4b6b04cc7f..7622d0bf49bb0f329d5d1724f4f7eb9bae7ee40a 100644
@@ -561,9 +561,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.16.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "45e110675d93f99f82c23a1afcca76bc")
+        set(KLEIDIAI_ARCHIVE_MD5  "0a9e9008adb6031f9e8cf70dff4a3321")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -615,6 +615,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED)
 
         set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
 
@@ -659,6 +660,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
         endif()
 
+        if (NOT SVE_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)
+        endif()
+
         set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
         list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
     endif()
diff --git a/src/ggml-cpu/kleidiai/kernels.cpp b/src/ggml-cpu/kleidiai/kernels.cpp
index 55a00f008ac03c419bb4d9429e9161291180e824..d114f2d49bfae0738084a83b77a0837ff64fd030 100644
@@ -18,6 +18,8 @@
 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
@@ -69,9 +71,9 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
 
 template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
 static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
-                                     const void* lhs, const void* rhs, void* dst,
-                                     size_t dst_stride_row, size_t dst_stride_col,
-                                     float clamp_min, float clamp_max) {
+                                         const void* lhs, const void* rhs, void* dst,
+                                         size_t dst_stride_row, size_t dst_stride_col,
+                                         float clamp_min, float clamp_max) {
     Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
 }
 
@@ -152,8 +154,8 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
 
 template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
 static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
-                                               size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
-                                               void* rhs_packed, size_t extra_bytes, const void* params) {
+                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+                                       void* rhs_packed, size_t extra_bytes, const void* params) {
     Fn(num_groups, n, k, nr, kr, sr,
        static_cast<const int8_t*>(rhs),
        static_cast<const float*>(bias),
@@ -524,6 +526,61 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
     },
 #endif
 #else
+#if defined(__ARM_FEATURE_SVE)
+    {
+        /* SVE i8mm GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+        },
+        /* SVE dotprod GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+        },
+        /* .required_cpu       = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
+        /* .lhs_type           = */ GGML_TYPE_F32,
+        /* .rhs_type           = */ GGML_TYPE_Q4_0,
+        /* .op_type            = */ GGML_TYPE_F32,
+    },
+#endif
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     {
         /* i8mm GEMM */
@@ -578,7 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
     },
-#endif
+#endif // __ARM_FEATURE_MATMUL_INT8
 #if defined(__ARM_FEATURE_DOTPROD)
     {
         /* DOTPROD GEMM */
@@ -811,26 +868,27 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
     ggml_kleidiai_kernels * kernel = nullptr;
 
     if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
-        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
-            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
-                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
-                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
-                gemm_gemv_kernels[i].op_type  == tensor->type) {
-                kernel = &gemm_gemv_kernels[i];
-                break;
-            }
-        }
-        if (!kernel) {
-            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
-                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
-                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
-                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
-                    gemm_gemv_kernels_q8[i].op_type  == tensor->type) {
-                    kernel = &gemm_gemv_kernels_q8[i];
-                    break;
+#if defined(__ARM_FEATURE_SME)          ||  \
+    defined(__ARM_FEATURE_DOTPROD)      ||  \
+    defined(__ARM_FEATURE_MATMUL_INT8)  ||  \
+    defined(__ARM_FEATURE_SVE)
+        auto try_table = [&](auto & table) {
+            for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
+                if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
+                    table[i].lhs_type == tensor->src[1]->type &&
+                    table[i].rhs_type == tensor->src[0]->type &&
+                    table[i].op_type  == tensor->type) {
+                    kernel = &table[i];
+                    return true;
                 }
             }
+            return false;
+        };
+
+        if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
+            try_table(gemm_gemv_kernels_q8);
+        } else {
+            try_table(gemm_gemv_kernels);
         }
 #else
     GGML_UNUSED(gemm_gemv_kernels);
@@ -845,7 +903,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;
 
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SME)          ||  \
+    defined(__ARM_FEATURE_DOTPROD)      ||  \
+    defined(__ARM_FEATURE_MATMUL_INT8)  ||  \
+    defined(__ARM_FEATURE_SVE)
     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
         if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
             kernels = &gemm_gemv_kernels[i];
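The two copy-pasted search loops over gemm_gemv_kernels and gemm_gemv_kernels_q8 above are folded into a single generic lambda that is dispatched on the weight type. A minimal standalone sketch of that pattern, with hypothetical entry fields and type ids standing in for the real ggml descriptors (only the shape of the search mirrors the diff):

    #include <cstddef>

    // Hypothetical stand-ins for the real ggml_kleidiai_kernels descriptors and type ids.
    struct entry {
        unsigned required_cpu;
        int      lhs_type;
        int      rhs_type;
        int      op_type;
    };

    #define NELEMS(x) (sizeof(x) / sizeof((x)[0]))

    static entry kernels_main[] = { { 0x7u, 0, 2, 0 }, { 0u, 0, 0, 0 } };
    static entry kernels_q8[]   = { { 0x3u, 0, 8, 0 }, { 0u, 0, 0, 0 } };

    static entry * select_entry(unsigned cpu_features, int lhs, int rhs, int op) {
        entry * found = nullptr;

        // One generic lambda replaces both loops: `auto & table` binds to either array
        // by reference, so NELEMS still sees the real array extent.
        auto try_table = [&](auto & table) {
            for (std::size_t i = 0; i < NELEMS(table) - 1; ++i) {   // same "- 1" bound as the diff
                if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
                    table[i].lhs_type == lhs && table[i].rhs_type == rhs && table[i].op_type == op) {
                    found = &table[i];
                    return true;
                }
            }
            return false;
        };

        // Dispatch on the weight type first, as the new code does for GGML_TYPE_Q8_0.
        if (rhs == 8) { try_table(kernels_q8); } else { try_table(kernels_main); }
        return found;
    }

The "- 1" bound mirrors the existing loops, which stop one short of the end of each table.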
diff --git a/src/ggml-cpu/kleidiai/kleidiai.cpp b/src/ggml-cpu/kleidiai/kleidiai.cpp
index 6f2a90fbda7bdd31c4d87cb8634c64504bb8cd0f..ad23e73184e5e6255eb6b3f67fc87340e121bff6 100644
@@ -46,13 +46,20 @@ struct ggml_kleidiai_context {
 } static ctx = { CPU_FEATURE_NONE, NULL, NULL };
 
 static const char* cpu_feature_to_string(cpu_feature f) {
-    switch (f) {
-        case CPU_FEATURE_NONE:    return "NONE";
-        case CPU_FEATURE_DOTPROD: return "DOTPROD";
-        case CPU_FEATURE_I8MM:    return "I8MM";
-        case CPU_FEATURE_SVE:     return "SVE";
-        case CPU_FEATURE_SME:     return "SME";
-        default:                  return "UNKNOWN";
+    if (f == CPU_FEATURE_NONE) {
+        return "NONE";
+    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+        return "SME";
+    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
+        return "SVE";
+    }
+    else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
+        return "I8MM";
+    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
+        return "DOTPROD";
+    }
+    else {
+        return "UNKNOWN";
     }
 }
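The switch above matched only exact single-flag values, while ctx.features is built by OR-ing flags together; an exact match cannot describe a combined value such as DOTPROD | I8MM | SVE, which would fall through to the default. The rewrite instead reports the most capable flag that is set. A self-contained sketch of the same ordering, with hypothetical flag values in place of the real CPU_FEATURE_* constants:

    // Hypothetical flag values; the real CPU_FEATURE_* constants are bit flags in the backend.
    enum cpu_feature_t { F_NONE = 0, F_DOTPROD = 1 << 0, F_I8MM = 1 << 1, F_SVE = 1 << 2, F_SME = 1 << 3 };

    static const char * feature_to_string(int f) {
        if (f == F_NONE)                  return "NONE";
        if ((f & F_SME)     == F_SME)     return "SME";      // most capable flag wins
        if ((f & F_SVE)     == F_SVE)     return "SVE";
        if ((f & F_I8MM)    == F_I8MM)    return "I8MM";
        if ((f & F_DOTPROD) == F_DOTPROD) return "DOTPROD";
        return "UNKNOWN";
    }
    // feature_to_string(F_DOTPROD | F_I8MM | F_SVE) -> "SVE", where a switch would hit no case.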
 
@@ -68,7 +75,7 @@ static void init_kleidiai_context(void) {
 
         ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                         (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
-                        (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
+                        ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
 
         if (env_var) {
             sme_enabled = atoi(env_var);
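
The final hunk is the gate the commit title refers to: CPU_FEATURE_SVE is now advertised only when the runtime SVE vector length matches the width the new 256-bit kernels are built around. A rough sketch of the arithmetic, assuming ggml_cpu_get_sve_cnt() reports the SVE register width in bytes and QK8_0 is the 32-element (and, for 8-bit data, 32-byte) Q8_0 block size:

    #include <cstdio>

    // Hypothetical stand-ins; the real values come from the ggml CPU backend.
    constexpr int QK8_0 = 32;                   // Q8_0 block size: 32 int8 elements = 32 bytes
    static int fake_sve_cnt() { return 32; }    // pretend runtime query for the SVE width, in bytes

    int main() {
        const int  sve_bytes          = fake_sve_cnt();
        const bool enable_sve_kernels = (sve_bytes == QK8_0);   // 32 bytes * 8 = 256-bit vectors
        std::printf("SVE VL = %d bits -> 256-bit KleidiAI kernels %s\n",
                    sve_bytes * 8, enable_sve_kernels ? "selected" : "skipped");
        return 0;
    }

On an implementation with a different vector length (128-bit, 512-bit, ...), the check fails, the SVE flag stays clear, and kernel selection falls through to the existing i8mm/dotprod entries.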