// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
#include "kai_common.h"
#include "simd-mappings.h"
+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
+template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
+static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
+ const void* lhs, const void* rhs, void* dst,
+ size_t dst_stride_row, size_t dst_stride_col,
+ float clamp_min, float clamp_max) {
+ Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+}
+
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
return Fn(m, k, bl, mr, kr, sr);
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
+static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
+ size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
+ Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
+}
+
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
return Fn(n, k, nr, kr, bl);
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
}
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
+static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
+ size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+ void* rhs_packed, size_t extra_bytes, const void* params) {
+ Fn(num_groups, n, k, nr, kr, sr,
+ static_cast<const int8_t*>(rhs),
+ static_cast<const float*>(bias),
+ static_cast<const float*>(scale),
+ rhs_packed, extra_bytes,
+ static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
+}
+
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
GGML_UNUSED(kr);
}
+static void dequantize_row_qsi8cxp(
+ const void *packed_data,
+ int32_t row_idx,
+ int64_t k,
+ float *out,
+ size_t nr,
+ size_t packed_row_stride,
+ size_t kr,
+ size_t bl,
+ size_t num_bytes_multiplier
+) {
+ GGML_UNUSED(bl);
+ GGML_UNUSED(num_bytes_multiplier);
+
+ const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
+ const size_t group_idx = row_idx / nr;
+ const size_t row_in_group = row_idx % nr;
+
+ const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
+ const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
+
+ const size_t num_blocks = k_internal / kr;
+
+ for (size_t block = 0; block < num_blocks; ++block) {
+ const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
+ for (size_t i = 0; i < kr; ++i) {
+ const size_t k_idx = block * kr + i;
+ if (k_idx < (size_t) k) {
+ out[k_idx] = static_cast<float>(block_ptr[i]);
+ }
+ }
+ }
+
+ const uint8_t * sums_ptr = group_ptr + nr * k_internal;
+ GGML_UNUSED(sums_ptr);
+
+ const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
+ const float scale = scale_ptr[row_in_group];
+
+ if (scale == 0.0f) {
+ for (size_t i = 0; i < (size_t) k; ++i) {
+ out[i] = 0.0f;
+ }
+ return;
+ }
+
+ for (size_t i = 0; i < (size_t) k; ++i) {
+ out[i] *= scale;
+ }
+}
+
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
#endif
};
+static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
+#if defined(__ARM_FEATURE_SME)
+ {
+ /* SME GEMM */
+ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+ },
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+ },
+ /* SME GEMV */
+ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+ },
+ /* .gemv_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+ },
+ /* .rhs_info = */ {
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+ /* .to_float = */ dequantize_row_qsi8cxp,
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ },
+ /* .required_cpu = */ CPU_FEATURE_SME,
+ /* .lhs_type = */ GGML_TYPE_F32,
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
+ /* .op_type = */ GGML_TYPE_F32,
+ },
+#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+ {
+ /* I8MM GEMM */
+ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+ },
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+ },
+ /* I8MM GEMV (dotprod fallback) */
+ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+ },
+ /* .gemv_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+ },
+ /* .rhs_info = */ {
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+ /* .to_float = */ dequantize_row_qsi8cxp,
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ },
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+ /* .lhs_type = */ GGML_TYPE_F32,
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
+ /* .op_type = */ GGML_TYPE_F32,
+ },
+#endif
+#if defined(__ARM_FEATURE_DOTPROD)
+ {
+ /* DOTPROD GEMM */
+ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+ },
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+ },
+ /* DOTPROD GEMV */
+ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+ /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+ },
+ /* .gemv_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+ },
+ /* .rhs_info = */ {
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+ /* .to_float = */ dequantize_row_qsi8cxp,
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+ },
+ /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+ /* .lhs_type = */ GGML_TYPE_F32,
+ /* .rhs_type = */ GGML_TYPE_Q8_0,
+ /* .op_type = */ GGML_TYPE_F32,
+ },
+#endif
+};
+
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
ggml_kleidiai_kernels * kernel = nullptr;
break;
}
}
+ if (!kernel) {
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+ if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
+ gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
+ gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
+ gemm_gemv_kernels_q8[i].op_type == tensor->type) {
+ kernel = &gemm_gemv_kernels_q8[i];
+ break;
+ }
+ }
+ }
#endif
}
return kernels;
}
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
+ ggml_kleidiai_kernels * kernels = nullptr;
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+ for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+ if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
+ kernels = &gemm_gemv_kernels_q8[i];
+ break;
+ }
+ }
+#endif
+
+ return kernels;
+}
#include <assert.h>
#include <atomic>
#include <cfloat>
+#include <cmath>
+#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
+#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
struct ggml_kleidiai_context {
cpu_feature features;
- ggml_kleidiai_kernels * kernels;
-} static ctx = { CPU_FEATURE_NONE, NULL };
+ ggml_kleidiai_kernels * kernels_q4;
+ ggml_kleidiai_kernels * kernels_q8;
+} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
static const char* cpu_feature_to_string(cpu_feature f) {
switch (f) {
if (sme_enabled != 0) {
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
- ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+ ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+ ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
- if (ctx.kernels) {
- GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+ if (ctx.kernels_q4) {
+ GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+ }
+ if (ctx.kernels_q8) {
+ GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
}
#endif
}
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
+ } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
+ if (!lhs_info->packed_size_ex) return false;
+ size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
+ } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+ return compute_forward_q8_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_fp16(params, dst);
}
} else if (dst->op == GGML_OP_GET_ROWS) {
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+ if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_get_rows(params, dst);
}
}
return true;
}
- bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
- if (!ctx.kernels) {
+ bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+ if (!kernels) {
return false;
}
+ bool is_gemv = src1->ne[1] == 1;
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+
+ if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
+ !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+ return false;
+ }
+
+ const int ith = params->ith;
+ const int nth_raw = params->nth;
+ const int nth = nth_raw > 0 ? nth_raw : 1;
+
+ const size_t k = ne00;
+ const size_t m = ne11;
+ const size_t n = ne01;
+
+ size_t mr = kernel->get_mr();
+ size_t kr = kernel->get_kr();
+ size_t sr = kernel->get_sr();
+
+ const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
+ uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
+ const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+ const size_t n_step = kernel->get_n_step();
+ const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+ const size_t n_start = ith * num_n_per_thread;
+
+ size_t n_to_process = 0;
+ if (n_start < n) {
+ n_to_process = num_n_per_thread;
+ if ((n_start + n_to_process) > n) {
+ n_to_process = n - n_start;
+ }
+ }
+
+ const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+ const size_t m_start = ith * num_m_per_thread;
+ size_t m_to_process = num_m_per_thread;
+ if ((m_start + m_to_process) > m) {
+ m_to_process = m - m_start;
+ }
+
+ if (m_start < m) {
+ const size_t src_stride = src1->nb[1];
+ const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+ const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
+ void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+ lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ const size_t dst_stride = dst->nb[1];
+ const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
+ const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
+ const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+ const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+ const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
+ float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+ if (n_to_process > 0) {
+ kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+ sizeof(float), -FLT_MAX, FLT_MAX);
+ }
+
+ return true;
+ }
+
+ bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
- rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
- kernel_info * kernel = &ctx.kernels->gemm;
+ ggml_kleidiai_kernels * kernels = nullptr;
+ size_t block_len = 0;
+ size_t num_bytes_multiplier = 0;
+
+ if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+ if (!ctx.kernels_q4) {
+ return false;
+ }
+ kernels = ctx.kernels_q4;
+ block_len = QK4_0;
+ num_bytes_multiplier = sizeof(uint16_t);
+ } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+ if (!ctx.kernels_q8) {
+ return false;
+ }
+ kernels = ctx.kernels_q8;
+ block_len = QK8_0;
+ num_bytes_multiplier = sizeof(float);
+ } else {
+ return false;
+ }
+
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
+ kernel_info * kernel = &kernels->gemm;
if (!rhs_info->to_float || !kernel->get_nr) {
return false;
}
const size_t block_rows = kernel->get_nr();
const size_t kr = kernel->get_kr();
- const size_t num_bytes_multiplier = sizeof(uint16_t);
- const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+ const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
const int ith = params->ith;
const int nth = params->nth;
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
float *out = (float *)((char *)dst->data + i * nb1);
- rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+ rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
}
return true;
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
- GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
- GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
- size_t nr = ctx.kernels->gemm.get_nr();
- size_t kr = ctx.kernels->gemm.get_kr();
- size_t sr = ctx.kernels->gemm.get_sr();
- struct kai_rhs_pack_qs4cxs1s0_param params;
- params.lhs_zero_point = 1;
- params.rhs_zero_point = 8;
- ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms);
+ if (tensor->type == GGML_TYPE_Q4_0) {
+ if (!ctx.kernels_q4) {
+ return -1;
+ }
+ size_t nr = ctx.kernels_q4->gemm.get_nr();
+ size_t kr = ctx.kernels_q4->gemm.get_kr();
+ size_t sr = ctx.kernels_q4->gemm.get_sr();
+
+ struct kai_rhs_pack_qs4cxs1s0_param params;
+ params.lhs_zero_point = 1;
+ params.rhs_zero_point = 8;
+ ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+ static_cast<const uint8_t *>(data),
+ nullptr, nullptr, tensor->data, 0, ¶ms);
+ GGML_UNUSED(data_size);
+ return 0;
+ } else if (tensor->type == GGML_TYPE_Q8_0) {
+ if (!ctx.kernels_q8) {
+ return -1;
+ }
+
+ const size_t row_stride = tensor->nb[1];
+ const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
+
+ std::vector<int8_t> qdata(n * k, 0);
+ std::vector<float> scales(n, 0.0f);
+
+ for (size_t row = 0; row < n; ++row) {
+ const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
+ static_cast<const uint8_t *>(data) + row * row_stride);
+
+ float max_abs = 0.0f;
+ for (size_t block = 0; block < k_blocks; ++block) {
+ const block_q8_0 & blk = row_blocks[block];
+ const float d = GGML_FP16_TO_FP32(blk.d);
+ for (size_t l = 0; l < QK8_0; ++l) {
+ const size_t linear_idx = block * QK8_0 + l;
+ if (linear_idx >= k) {
+ break;
+ }
+ const float value = d * blk.qs[l];
+ max_abs = std::max(max_abs, std::fabs(value));
+ }
+ }
+
+ float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
+ scales[row] = scale;
+ const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
+
+ for (size_t block = 0; block < k_blocks; ++block) {
+ const block_q8_0 & blk = row_blocks[block];
+ const float d = GGML_FP16_TO_FP32(blk.d);
+ for (size_t l = 0; l < QK8_0; ++l) {
+ const size_t linear_idx = block * QK8_0 + l;
+ if (linear_idx >= k) {
+ break;
+ }
+ const float value = d * blk.qs[l];
+ int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
+ q = std::clamp(q, -127, 127);
+ qdata[row * k + linear_idx] = static_cast<int8_t>(q);
+ }
+ }
+ }
+
+ size_t nr = ctx.kernels_q8->gemm.get_nr();
+ size_t kr = ctx.kernels_q8->gemm.get_kr();
+ size_t sr = ctx.kernels_q8->gemm.get_sr();
+
+ struct kai_rhs_pack_qsi8cx_params params;
+ params.lhs_zero_point = 1;
+ params.scale_multiplier = 1.0f;
+
+ ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+ qdata.data(), nullptr, scales.data(),
+ tensor->data, 0, ¶ms);
+ GGML_UNUSED(data_size);
+ return 0;
+ }
- return 0;
GGML_UNUSED(data_size);
+ return -1;
}
};
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
- GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
- GGML_ASSERT(ctx.kernels);
+ GGML_UNUSED(buft);
- const size_t n = tensor->ne[1];
- const size_t k = tensor->ne[0];
- const size_t nr = ctx.kernels->gemm.get_nr();
- const size_t kr = ctx.kernels->gemm.get_kr();
+ const size_t n = tensor->ne[1];
+ const size_t k = tensor->ne[0];
+
+ ggml_kleidiai_kernels * kernels = nullptr;
+ size_t block_len = 0;
+
+ if (tensor->type == GGML_TYPE_Q4_0) {
+ GGML_ASSERT(ctx.kernels_q4);
+ kernels = ctx.kernels_q4;
+ block_len = QK4_0;
+ } else if (tensor->type == GGML_TYPE_Q8_0) {
+ GGML_ASSERT(ctx.kernels_q8);
+ kernels = ctx.kernels_q8;
+ block_len = QK8_0;
+ } else {
+ return 0;
+ }
- return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
+ const size_t nr = kernels->gemm.get_nr();
+ const size_t kr = kernels->gemm.get_kr();
+ const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
+ const size_t raw = ggml_nbytes(tensor);
- GGML_UNUSED(buft);
+ return packed > raw ? packed : raw;
}
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
- op->src[0]->type == GGML_TYPE_Q4_0 &&
+ (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+ if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+ return false;
+ }
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}