#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
- const void* lhs, const void* rhs, void* dst,
- size_t dst_stride_row, size_t dst_stride_col,
- float clamp_min, float clamp_max) {
+ const void* lhs, const void* rhs, void* dst,
+ size_t dst_stride_row, size_t dst_stride_col,
+ float clamp_min, float clamp_max) {
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
- size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
- void* rhs_packed, size_t extra_bytes, const void* params) {
+ size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+ void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr,
static_cast<const int8_t*>(rhs),
static_cast<const float*>(bias),
},
#endif
#else
+#if defined(__ARM_FEATURE_SVE)
+ {
+ /* SVE i8mm GEMM */
+ /* .kern_info = */ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+ /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+ },
+ /* .gemm_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+ /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+ },
+ /* SVE dotprod GEMV */
+ /* .kern_info = */ {
+ /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+ /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+ /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+ /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+ },
+ /* .gemv_lhs_info = */ {
+ /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+ /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+ /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+ /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
+ },
+ /* .rhs_info = */ {
+ /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+ /* .to_float = */ dequantize_row_qsi4c32pscalef16,
+ /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+ /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+ /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+ },
+ /* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
+ /* .lhs_type = */ GGML_TYPE_F32,
+ /* .rhs_type = */ GGML_TYPE_Q4_0,
+ /* .op_type = */ GGML_TYPE_F32,
+ },
+#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* i8mm GEMM */
/* .rhs_type = */ GGML_TYPE_Q4_0,
/* .op_type = */ GGML_TYPE_F32,
},
-#endif
+#endif // __ARM_FEATURE_MATMUL_INT8
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
ggml_kleidiai_kernels * kernel = nullptr;
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
- for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
- if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
- gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
- gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
- gemm_gemv_kernels[i].op_type == tensor->type) {
- kernel = &gemm_gemv_kernels[i];
- break;
- }
- }
- if (!kernel) {
- for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
- if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
- gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
- gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
- gemm_gemv_kernels_q8[i].op_type == tensor->type) {
- kernel = &gemm_gemv_kernels_q8[i];
- break;
+#if defined(__ARM_FEATURE_SME) || \
+ defined(__ARM_FEATURE_DOTPROD) || \
+ defined(__ARM_FEATURE_MATMUL_INT8) || \
+ defined(__ARM_FEATURE_SVE)
+ auto try_table = [&](auto & table) {
+ for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
+ if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
+ table[i].lhs_type == tensor->src[1]->type &&
+ table[i].rhs_type == tensor->src[0]->type &&
+ table[i].op_type == tensor->type) {
+ kernel = &table[i];
+ return true;
}
}
+ return false;
+ };
+
+ if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
+ try_table(gemm_gemv_kernels_q8);
+ } else {
+ try_table(gemm_gemv_kernels);
}
#else
GGML_UNUSED(gemm_gemv_kernels);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SME) || \
+ defined(__ARM_FEATURE_DOTPROD) || \
+ defined(__ARM_FEATURE_MATMUL_INT8) || \
+ defined(__ARM_FEATURE_SVE)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
kernels = &gemm_gemv_kernels[i];