]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kleidiai: kernel interface refactoring (#16460)
authorCharles Xu <redacted>
Thu, 9 Oct 2025 07:29:17 +0000 (09:29 +0200)
committerGitHub <redacted>
Thu, 9 Oct 2025 07:29:17 +0000 (10:29 +0300)
ggml/src/ggml-cpu/kleidiai/kernels.cpp
ggml/src/ggml-cpu/kleidiai/kernels.h
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

index 7ba659124ca27c03d76bdf96a7455390b2ff3997..3eaa5e3f4100fcc81aa007ca01e99aa47e67aa38 100644 (file)
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
 
+template<size_t(*Fn)(size_t,size_t,size_t)>
+static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
+    return Fn(a, b, c);
+}
+
+template<size_t(*Fn)(size_t,size_t)>
+static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {
+    return Fn(a, b);
+}
+
+template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
+static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,
+                                     const void* lhs, const void* rhs, void* dst,
+                                     size_t dst_stride_row, size_t dst_stride_col,
+                                     float clamp_min, float clamp_max) {
+    Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+}
+
+template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>
+static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
+                                   const void* lhs, const void* rhs, void* dst,
+                                   size_t dst_stride_row, size_t dst_stride_col,
+                                   float clamp_min, float clamp_max) {
+    Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+}
+
+template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
+static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
+    return Fn(m, k, bl, mr, kr, sr);
+}
+
+template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
+static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
+    return Fn(m, k, mr, kr, sr);
+}
+
+template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
+static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
+    return Fn(m_idx, k, bl, mr, kr, sr);
+}
+
+template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
+static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
+    return Fn(m_idx, k, mr, kr, sr);
+}
+
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
+static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
+                                            size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
+    Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
+}
+
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
+static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
+                                           size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
+    Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
+}
+
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
+static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
+                                             size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
+    Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
+}
+
+template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
+static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
+    return Fn(n, k, nr, kr, bl);
+}
+
+template<size_t(*Fn)(size_t,size_t)>
+static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
+    return Fn(n, k);
+}
+
+template<size_t(*Fn)(size_t,size_t,size_t,size_t)>
+static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {
+    return Fn(k, nr, kr, bl);
+}
+
+template<size_t(*Fn)(size_t)>
+static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
+    return Fn(k);
+}
+
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>
+static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
+                                      size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,
+                                      void* rhs_packed, size_t extra_bytes, const void* params) {
+    Fn(num_groups, n, k, nr, kr, sr, bl,
+       static_cast<const uint8_t*>(rhs),
+       static_cast<const float*>(bias),
+       rhs_packed, extra_bytes,
+       static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
+}
+
+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
+static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
+                                               size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
+                                               void* rhs_packed, size_t extra_bytes, const void* params) {
+    Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);
+}
+
 static const size_t INT4_PER_BYTE = 2;
 static const size_t INT4_BITS     = 4;
 static const int Q4_0_ZERO_POINT  = 8;
@@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
         },
+
         /* .gemm_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
         },
         /* SME GEMV */
         /* .kern_info = */ {
@@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
         },
         /* .gemv_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float              = */ dequantize_row_qsi4c32ps1s0scalef16,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
+            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
+            /* .run_kernel_ex         = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
         },
         /* .gemm_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
-            /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
+            /* .pack_func_ex          = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
         },
         /* SME GEMV */
         /* .kern_info = */ {
@@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
+            /* .get_lhs_offset_ex     = */ nullptr,
+            /* .get_rhs_packed_offset_ex = */ nullptr,
+            /* .run_kernel_ex         = */ nullptr,
         },
         /* .gemv_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
-            /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
+            /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
+            /* .pack_func_ex          = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .packed_stride = */ NULL,
-            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .to_float      = */ NULL,
+            /* .packed_stride         = */ nullptr,
+            /* .to_float              = */ nullptr,
+            /* .packed_size_ex        = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
+            /* .pack_func_ex          = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
         },
         /* .gemm_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
         },
         /* DOTPROD GEMV */
         /* .kern_info = */ {
@@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
         },
         /* .gemv_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
         },
         /* .gemm_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
         },
         /* i8mm GEMV */
         /* .kern_info = */ {
@@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
         },
         /* .gemv_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
         },
         /* .gemm_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
         },
         /* i8mm GEMV */
         /* .kern_info = */ {
@@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
         },
         /* .gemv_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
         },
         /* .gemm_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
         },
         /* DOTPROD GEMV */
         /* .kern_info = */ {
@@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
-            /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
-            /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
-            /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
+            /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
         },
         /* .gemv_lhs_info = */ {
             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
-            /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
-            /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
         },
         /* .rhs_info = */ {
-            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float              = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
     ggml_kleidiai_kernels * kernel = nullptr;
 
     if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
         for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
             if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
                 gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
@@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
                 break;
             }
         }
+#endif
     }
 
     return kernel;
@@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;
 
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
         if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
             kernels = &gemm_gemv_kernels[i];
             break;
         }
     }
+#endif
 
     return kernels;
 }
index 2ad6ad6fd0bfc5b06db0514d2fb4bdccbc326e87..a84795a6b2e50841fc65ff6115e06dda93613642 100644 (file)
@@ -4,8 +4,6 @@
 
 #pragma once
 
-#include <functional>
-#include <variant>
 #include "ggml.h"
 
 enum cpu_feature {
@@ -15,6 +13,7 @@ enum cpu_feature {
     CPU_FEATURE_SVE     = 4,
     CPU_FEATURE_SME     = 8
 };
+
 inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
     lhs = static_cast<cpu_feature>(lhs | rhs);
     return lhs;
@@ -30,63 +29,52 @@ struct kernel_info {
     size_t (*get_nr)(void);
     size_t (*get_kr)(void);
     size_t (*get_sr)(void);
-    std::variant<
-        std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
-        std::function<size_t(size_t m_idx, size_t k)>
-    > get_lhs_offset;
-    std::variant<
-        std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
-        std::function<size_t(size_t n_idx, size_t k)>
-    > get_rhs_packed_offset;
+
     size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
     size_t (*get_dst_size)(size_t m, size_t n);
-    std::variant<
-        std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
-            float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
-        std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
-            size_t dst_stride_col, float clamp_min, float clamp_max)>
-    > run_kernel;
+
+    size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);
+
+    size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);
+
+    void (*run_kernel_ex)(
+        size_t m, size_t n, size_t k, size_t bl,
+        const void* lhs_packed, const void* rhs_packed,
+        void* dst, size_t dst_stride_row, size_t dst_stride_col,
+        float clamp_min, float clamp_max);
 };
 
 struct lhs_packing_info {
     size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
-    std::variant<
-        std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
-        std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
-    > get_packed_offset;
-    std::variant<
-        std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
-        std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
-    > packed_size;
-    std::variant<
-        std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
-            size_t lhs_stride, void* lhs_packed)>,
-        std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
-        void* lhs_packed)>
-    > pack_func;
+
+    size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+
+    size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
+
+    void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
+        size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
 };
 
 struct rhs_packing_info {
-    std::variant<
-        std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
-        std::function<size_t(size_t n, size_t k)>
-    > packed_size;
     size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
-    std::variant<
-        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
-            const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
-        std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
-            const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
-    > pack_func;
-    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
-          size_t kr, size_t bl, size_t num_bytes_multiplier);
+
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,
+                     size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,
+                     size_t num_bytes_multiplier);
+
+    size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
+
+    size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);
+
+    void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
+        size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);
 };
 
 struct ggml_kleidiai_kernels {
-    kernel_info gemm;
+    kernel_info      gemm;
     lhs_packing_info gemm_lhs_info;
 
-    kernel_info gemv;
+    kernel_info      gemv;
     lhs_packing_info gemv_lhs_info;
 
     rhs_packing_info rhs_info;
index 44691e5dfdf6a598873014a38a181ea6ca4b2a51..8b3df7d78009e527eb9714a8accd1bd9ef27f138 100644 (file)
@@ -8,6 +8,7 @@
 #include <stdexcept>
 #include <stdint.h>
 #include <string.h>
+#include <string>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
@@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }
 
-template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
-constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
-    using V = std::remove_reference_t<Variant>;
-    return (std::is_invocable_r_v<
-                Ret,
-                std::variant_alternative_t<Is, V>,
-                Args...> || ...);
-}
-
-template <typename Variant, typename Ret, typename... Args>
-constexpr bool variant_any_invocable_v =
-    variant_any_invocable_impl<Variant, Ret, Args...>(
-        std::make_index_sequence<
-            std::variant_size_v<std::remove_reference_t<Variant>>>{});
-
-template<typename Ret, typename Variant, typename... Args>
-static inline Ret variant_call(Variant && var, Args&&... args) {
-    static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
-                  "No alternative in Variant is invocable with the provided arguments and return type.");
-
-    return std::visit(
-        [&](auto && f) -> Ret {
-            using F = std::decay_t<decltype(f)>;
-            if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
-                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
-            } else {
-                GGML_ABORT("Invalid function type in variant_call");
-                GGML_UNREACHABLE();
-            }
-        },
-        std::forward<Variant>(var)
-    );
-}
-
 namespace ggml::cpu::kleidiai {
 
 static size_t round_down(size_t x, size_t y) {
@@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             return false;
         }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
-        GGML_ASSERT(kernels);
+        if (!kernels) {
+            return false;
+        }
         bool is_gemv = op->src[1]->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
@@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         size_t sr = kernel->get_sr();
 
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
+            if (!lhs_info->packed_size_ex) return false;
+            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
+            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
             const int64_t lhs_batch_size0 = op->src[1]->ne[2];
             const int64_t rhs_batch_size0 = op->src[0]->ne[2];
             const int64_t r = lhs_batch_size0 / rhs_batch_size0;
-            size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
-                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
+            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
+                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
-            GGML_ASSERT(false);
+            return false;
         }
 
         return true;
@@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_TENSOR_BINARY_OP_LOCALS
 
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        GGML_ASSERT(kernels);
+        if (!kernels) {
+            return false;
+        }
 
         const bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
+        if (!kernels->rhs_info.pack_func_ex ||
+            !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
+            return false;
+        }
 
         const int nth = params->nth;
         const int ith = params->ith;
@@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t kr = (int64_t) kernel->get_kr();
         const int64_t sr = (int64_t) kernel->get_sr();
 
-        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
-        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
-        const size_t kxn_size        = (size_t)k * (size_t)n * sizeof(float);
-        const size_t bias_size       = (size_t)n * sizeof(float);
+        const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
+        const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
+        const size_t kxn_size        = k * n * sizeof(float);
+        const size_t bias_size       = n * sizeof(float);
 
         const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
         GGML_ASSERT(wsize_required <= params->wsize);
@@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                     const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
 
                     // Base packed offset (aligned) and per-row stride in bytes
-                    const size_t base_packed_off = variant_call<size_t>(
-                        lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
-                    const size_t next_block_off = variant_call<size_t>(
-                        lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t base_packed_off  = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
+                    const size_t next_block_off   = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
                     const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
 
                     int64_t remaining = m_count;
@@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                         const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
                         void * dst_ptr       = lhs_packed + dst_off;
 
-                        variant_call<void>(lhs_info->pack_func,
-                                        (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
-                                        /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
+                        lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
 
                         cur       += take;
                         remaining -= take;
@@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                                         reinterpret_cast<const uint16_t *>(rhs_batch_base),
                                         rhs_stride);
 
-                variant_call<void>(kernels->rhs_info.pack_func,
-                                   /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
-                                   /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
-                                   rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
+                kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
+                             rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
             }
 
             ggml_barrier(params->threadpool);
@@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
                     const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
 
                     // LHS packed base at row 0 (consistent with packing above)
-                    const size_t lhs_packed_offset0 = variant_call<size_t>(
-                        lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
-                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
-                    const size_t dst_offset        = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
+                    const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
+                    const size_t rhs_packed_offset  = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
+                    const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
 
                     const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                     const void * rhs_ptr = rhs_packed + rhs_packed_offset;
                     float * dst_ptr      = reinterpret_cast<float *>(dst_batch_base + dst_offset);
 
-                    variant_call<void>(kernel->run_kernel,
-                                       (size_t)m, (size_t)n_to_process, (size_t)k,
-                                       lhs_ptr, rhs_ptr,
-                                       dst_ptr, dst_stride, sizeof(float),
-                                       -FLT_MAX, FLT_MAX);
+                    kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
                 }
             }
 
@@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_TENSOR_BINARY_OP_LOCALS
 
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        GGML_ASSERT(kernels);
+        if (!kernels) {
+            return false;
+        }
 
         bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
         GGML_ASSERT(kernel);
+        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
+            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+            return false;
+        }
 
         const int ith = params->ith;
         const int nth_raw = params->nth;
@@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             // Transform LHS
             const size_t src_stride        = src1->nb[1];
             const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
             void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
 
-            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
+            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
         }
 
         ggml_barrier(params->threadpool);
 
         // Perform the operation
         const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
-        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
+        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
+        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
         const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
         const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
         const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
         float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
 
         if (n_to_process > 0) {
-            variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
                                sizeof(float), -FLT_MAX, FLT_MAX);
         }
 
@@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
     bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
         GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
+        if (!ctx.kernels) {
+            return false;
+        }
 
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
@@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
         rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
         kernel_info * kernel        = &ctx.kernels->gemm;
+        if (!rhs_info->to_float || !kernel->get_nr) {
+            return false;
+        }
 
         const int64_t nc     = ne00;
         const int64_t nr     = ggml_nelements(src1);
@@ -480,7 +458,7 @@ public:
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
+        ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
 
         return 0;
         GGML_UNUSED(data_size);
@@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_
     const size_t nr = ctx.kernels->gemm.get_nr();
     const size_t kr = ctx.kernels->gemm.get_kr();
 
-    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+    return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
 
     GGML_UNUSED(buft);
 }