]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kleidiai : add sme fp16 compute path for q4_0 gemm on aarch64 (#20043)
authorCharles Xu <redacted>
Tue, 3 Mar 2026 09:40:26 +0000 (10:40 +0100)
committerGitHub <redacted>
Tue, 3 Mar 2026 09:40:26 +0000 (11:40 +0200)
ggml/src/ggml-cpu/CMakeLists.txt
ggml/src/ggml-cpu/kleidiai/kernels.cpp

index 3dc948e4d8e651a8042bfe54aa0174d014f71954..6ca3176a2f241f8f851f71b3e42de56a0e22706e 100644 (file)
@@ -566,9 +566,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.16.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.22.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "0a9e9008adb6031f9e8cf70dff4a3321")
+        set(KLEIDIAI_ARCHIVE_MD5  "54049037570ab0ee0a0d126b2ba5ece1")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -608,6 +608,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/
             ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
 
         set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@@ -648,7 +649,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         if (NOT SME_ENABLED MATCHES -1)
             list(APPEND GGML_KLEIDIAI_SOURCES
-                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
@@ -656,10 +656,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c
                 ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
-            set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
+            set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16")
         endif()
 
         if (NOT SVE_ENABLED MATCHES -1)
index d114f2d49bfae0738084a83b77a0837ff64fd030..40f7c0df6505c6c03e19435effe76be39c972610 100644 (file)
@@ -1,4 +1,4 @@
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
 // SPDX-License-Identifier: MIT
 //
 
@@ -9,7 +9,6 @@
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
-#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
 #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
@@ -20,6 +19,7 @@
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
+#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h"
 
 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
@@ -31,6 +31,7 @@
 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
 #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
 #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
+#include "kai_lhs_pack_f16pmrx2_f32_neon.h"
 
 #include "kai_common.h"
 
@@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
     {
         /* SME GEMM */
         /* .kern_info = */ {
-            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
-            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
-            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
+            /* .get_m_step            = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_n_step            = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_mr                = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
         },
 
         /* .gemm_lhs_info = */ {
-            /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
-            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
-            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
+            /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_pack_f16pmrx2_f32_neon>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_pack_f16pmrx2_f32_neon>,
+            /* .pack_func_ex          = */ &lhs_pack_void_fn10<kai_run_lhs_pack_f16pmrx2_f32_neon>,
         },
         /* SME GEMV */
         /* .kern_info = */ {