]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kleidiai : support for concurrent sme and neon kernel execution (#20070)
authorCharles Xu <redacted>
Tue, 10 Mar 2026 07:25:25 +0000 (08:25 +0100)
committerGitHub <redacted>
Tue, 10 Mar 2026 07:25:25 +0000 (09:25 +0200)
docs/build.md
ggml/src/ggml-cpu/kleidiai/kernels.cpp
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

index 772731f6418472c926bfe3c419baf5faa984b967..0717a799ae392ff6ac043d984be9d9ed68c4f441 100644 (file)
@@ -599,7 +599,13 @@ If KleidiAI is enabled, the output will contain a line similar to:
 ```
 load_tensors: CPU_KLEIDIAI model buffer size =  3474.00 MiB
 ```
-KleidiAI's microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm and SME. llama.cpp selects the most efficient kernel based on runtime CPU feature detection. However, on platforms that support SME, you must manually enable SME microkernels by setting the environment variable `GGML_KLEIDIAI_SME=1`.
+KleidiAI’s microkernels implement optimized tensor operations using Arm CPU features such as dotprod, int8mm, SVE, and SME. Llama.cpp selects the most efficient kernels at runtime based on detected CPU capabilities.
+On CPUs that support SME, SME microkernels are enabled automatically using runtime detection.
+The environment variable GGML_KLEIDIAI_SME can be used to control SME behavior:
+- Not set: enable SME automatically if supported and detected.
+- 0: disable SME.
+- <n> > 0: enable SME and assume <n> available SME units (override auto detection).
+If SME is not supported by the CPU, SME microkernels are always disabled.
 
 Depending on your build target, other higher priority backends may be enabled by default. To ensure the CPU backend is used, you must disable the higher priority backends either at compile time, e.g. -DGGML_METAL=OFF, or during run-time using the command line option `--device none`.
 
index 40f7c0df6505c6c03e19435effe76be39c972610..8c4d7bc925f6ab80ecf7e376ab866504ed2b2c53 100644 (file)
@@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
@@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
@@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
             /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q8_0,
         /* .op_type            = */ GGML_TYPE_F32,
index ad23e73184e5e6255eb6b3f67fc87340e121bff6..9bcc18d442c1fc055318b3e9d8bae3f8aa901d93 100644 (file)
@@ -1,20 +1,31 @@
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
 // SPDX-License-Identifier: MIT
 //
 #include <arm_neon.h>
 #include <assert.h>
+#include <stdio.h>
 #include <atomic>
 #include <cfloat>
-#include <cmath>
 #include <algorithm>
+#include <cmath>
 #include <stdexcept>
 #include <stdint.h>
 #include <string.h>
 #include <string>
 #include <vector>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <set>
+#include <iostream>
+#include <climits>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
 #elif defined(__APPLE__)
 #include <string_view>
 #include <sys/sysctl.h>
 #define GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
 
+static constexpr int      GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
+static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC       = 0x4b4c4149; // "KLAI"
+static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION     = 1;
+static constexpr size_t   GGML_KLEIDIAI_PACK_ALIGN       = 64;
+
 struct ggml_kleidiai_context {
     cpu_feature features;
     ggml_kleidiai_kernels * kernels_q4;
     ggml_kleidiai_kernels * kernels_q8;
-} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
+    int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
+    int thread_hint;    // <= 0 means “no hint”
+} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
 
 static const char* cpu_feature_to_string(cpu_feature f) {
     if (f == CPU_FEATURE_NONE) {
@@ -63,41 +81,335 @@ static const char* cpu_feature_to_string(cpu_feature f) {
     }
 }
 
-static void init_kleidiai_context(void) {
+static size_t detect_num_smcus() {
+    if (!ggml_cpu_has_sme()) {
+        return 0;
+    }
+
+#if defined(__linux__) && defined(__aarch64__)
+    // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
+    size_t num_private = 0;
+    std::set<uint32_t> shared_ids;
+
+    for (size_t cpu = 0;; ++cpu) {
+        const std::string path =
+            "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
+            "/regs/identification/smidr_el1";
+
+        std::ifstream file(path);
+        if (!file.is_open()) {
+            break;
+        }
+
+        uint64_t smidr = 0;
+        if (!(file >> std::hex >> smidr)) {
+            continue;
+        }
+
+        // Arm ARM: SMIDR_EL1
+        const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
+        // Build an "affinity-like" identifier for shared SMCUs.
+        // Keep the original packing logic, but isolate it here.
+        const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
+
+        switch (sh) {
+            case 0b10: // private SMCU
+                ++num_private;
+                break;
+            case 0b11: // shared SMCU
+                shared_ids.emplace(id);
+                break;
+            case 0b00:
+                // Ambiguous / implementation-defined. Be conservative:
+                // treat id==0 as private, otherwise as shared.
+                if (id == 0) ++num_private;
+                else shared_ids.emplace(id);
+                break;
+            default:
+                break;
+        }
+    }
+
+    return num_private + shared_ids.size();
+
+#elif defined(__APPLE__) && defined(__aarch64__)
+    // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
+    char chip_name[256] = {};
+    size_t size = sizeof(chip_name);
+
+    if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
+        const std::string brand(chip_name);
+
+        struct ModelSMCU { const char *match; size_t smcus; };
+        static const ModelSMCU table[] = {
+            { "M4 Ultra", 2 },
+            { "M4 Max",   2 },
+            { "M4 Pro",   2 },
+            { "M4",       1 },
+        };
 
+        for (const auto &e : table) {
+            if (brand.find(e.match) != std::string::npos) {
+                return e.smcus;
+            }
+        }
+    }
+    return 1;
+
+#else
+    return 1;
+#endif
+}
+
+static int parse_uint_env(const char *s, const char *name, bool *ok) {
+    if (!s) { *ok = false; return 0; }
+    char *end = nullptr;
+    long v = strtol(s, &end, 10);
+    if (end == s || *end != '\0') {
+        GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
+        *ok = false;
+        return 0;
+    }
+    if (v < 0 || v > INT_MAX) {
+        GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
+        *ok = false;
+        return 0;
+    }
+    *ok = true;
+    return (int)v;
+}
+
+static void init_kleidiai_context(void) {
     ggml_critical_section_start();
     static bool initialized = false;
 
     if (!initialized) {
         initialized = true;
-        const char *env_var = getenv("GGML_KLEIDIAI_SME");
-        int sme_enabled = 0;
+
+        const char *env_sme     = getenv("GGML_KLEIDIAI_SME");
+        const char *env_threads = getenv("GGML_TOTAL_THREADS");
+
+        const bool cpu_has_sme = ggml_cpu_has_sme();
+        size_t detected_smcus = 0;
 
         ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                         (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                         ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
 
-        if (env_var) {
-            sme_enabled = atoi(env_var);
+        if (env_threads) {
+            bool ok = false;
+            int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
+            if (ok && hint > 0) {
+                ctx.thread_hint = hint;
+            }
         }
 
-        if (sme_enabled != 0) {
-            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+        // SME policy:
+        // - If CPU doesn't support SME: SME always off.
+        // - Else:
+        //   - env unset => auto-detect cores; enable if detected > 0.
+        //   - env=0     => force off.
+        //   - env>0     => force N cores (skip detection).
+        int sme_cores = 0;
+        bool sme_env_ok = false;
+        bool sme_env_set = (env_sme != nullptr);
+
+        if (!cpu_has_sme) {
+            if (sme_env_set) {
+                bool ok = false;
+                int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+                if (ok && req > 0) {
+                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
+                }
+            }
+            sme_cores = 0;
+        } else {
+            if (sme_env_set) {
+                bool ok = false;
+                int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+                sme_env_ok = ok;
+
+                if (!ok) {
+                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
+                    detected_smcus = detect_num_smcus();
+                    sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+                } else if (v == 0) {
+                    sme_cores = 0;
+                } else {
+                    sme_cores = v;
+                }
+            } else {
+                detected_smcus = detect_num_smcus();
+                sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+            }
+
+            if (!sme_env_set && sme_cores == 0) {
+                GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
+            }
+
+            if (sme_cores > 0) {
+                ctx.features |= CPU_FEATURE_SME;
+            }
         }
+
+        // Kernel selection
         ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
         ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
-#ifndef NDEBUG
-        if (ctx.kernels_q4) {
-            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+
+        if (!ctx.kernels_q4) {
+            GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
+        } else {
+            GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+        }
+
+        if (!ctx.kernels_q8) {
+            GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
+        } else {
+            GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
         }
-        if (ctx.kernels_q8) {
-            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
+
+        ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
+
+        if (ctx.features & CPU_FEATURE_SME) {
+            if (sme_env_set && sme_env_ok && sme_cores > 0) {
+                GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
+            } else {
+                GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
+            }
+        } else {
+            GGML_LOG_INFO("kleidiai: SME disabled\n");
         }
-#endif
     }
+
     ggml_critical_section_end();
 }
 
+static inline int kleidiai_sme_thread_cap() {
+    return ctx.sme_thread_cap;
+}
+
+static inline size_t align_up(size_t value, size_t alignment) {
+    if (alignment == 0) {
+        return value;
+    }
+    const size_t remainder = value % alignment;
+    return remainder == 0 ? value : value + (alignment - remainder);
+}
+
+static inline bool kleidiai_pack_fallback_allowed() {
+    if (ctx.sme_thread_cap <= 0) {
+        return false;
+    }
+    if (ctx.thread_hint <= 0) {
+        return true;
+    }
+    return ctx.thread_hint > ctx.sme_thread_cap;
+}
+
+struct kleidiai_weight_header {
+    uint32_t magic;
+    uint16_t version;
+    uint16_t slot_count;
+    uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+    uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+};
+
+static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
+    return reinterpret_cast<kleidiai_weight_header *>(data);
+}
+
+static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
+    return reinterpret_cast<const kleidiai_weight_header *>(data);
+}
+
+static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
+    if (!header) {
+        return false;
+    }
+    if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
+        return false;
+    }
+    if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
+        return false;
+    }
+    return true;
+}
+
+static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
+    if (!kleidiai_is_weight_header_valid(header)) {
+        return nullptr;
+    }
+    if (slot < 0 || slot >= header->slot_count) {
+        return nullptr;
+    }
+    return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
+    if (!kleidiai_is_weight_header_valid(header)) {
+        return nullptr;
+    }
+    if (slot < 0 || slot >= header->slot_count) {
+        return nullptr;
+    }
+    return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
+    return ctx.kernels_q4;
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
+    return ctx.kernels_q8;
+}
+
+template <typename SelectFallback>
+static int kleidiai_collect_kernel_chain_common(
+        ggml_kleidiai_kernels * primary,
+        cpu_feature features,
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
+        SelectFallback select_fallback) {
+    int count = 0;
+    if (!primary) {
+        return 0;
+    }
+    out[count++] = primary;
+
+    if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+        const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
+        if (fallback_mask != CPU_FEATURE_NONE) {
+            ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
+            if (fallback && fallback != primary &&
+                fallback->lhs_type == primary->lhs_type &&
+                fallback->rhs_type == primary->rhs_type &&
+                fallback->op_type  == primary->op_type) {
+                out[count++] = fallback;
+            }
+        }
+    }
+
+    return count;
+}
+
+static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+    ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
+}
+
+static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
+}
+
+static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
+}
+
 static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
     return tensor->ne[dim];
@@ -126,49 +438,108 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (op->op != GGML_OP_MUL_MAT) {
             return false;
         }
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
-        if (!kernels) {
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
+        if (slot_count == 0) {
             return false;
         }
-        bool is_gemv = op->src[1]->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
-        size_t k = op->src[0]->ne[0];
-        size_t n = op->src[0]->ne[1];
-        size_t m = op->src[1]->ne[1];
-
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
-
-        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            if (!lhs_info->packed_size_ex) return false;
-            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
-        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
-            if (!lhs_info->packed_size_ex) return false;
-            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
-        } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
+        const bool is_gemv = op->src[1]->ne[1] == 1;
+        const size_t k = op->src[0]->ne[0];
+        const size_t n = op->src[0]->ne[1];
+        const size_t m = op->src[1]->ne[1];
+
+        if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
+            const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
+
+            size_t cursor = 0;
+            bool any_slot = false;
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+                kernel_info * kernel        = is_gemv ? &kernels->gemv : &kernels->gemm;
+
+                if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
+                    return false;
+                }
+
+                const size_t mr = kernel->get_mr();
+                const size_t kr = kernel->get_kr();
+                const size_t sr = kernel->get_sr();
+
+                const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
+
+                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += packed;
+                any_slot = true;
+            }
+
+            if (!any_slot) {
+                return false;
+            }
+
+            size = cursor;
+            return true;
+        }
+
+        if (op->src[0]->type == GGML_TYPE_F16) {
             const int64_t lhs_batch_size0 = op->src[1]->ne[2];
             const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            GGML_ASSERT(rhs_batch_size0 > 0);
             const int64_t r = lhs_batch_size0 / rhs_batch_size0;
-            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
-                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
-                   k * n * sizeof(float) + n * sizeof(float);
-        } else {
-            return false;
+
+            size_t cursor = 0;
+            bool any_slot = false;
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+                kernel_info * kernel        = is_gemv ? &kernels->gemv : &kernels->gemm;
+                if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
+                    return false;
+                }
+
+                const size_t mr = kernel->get_mr();
+                const size_t kr = kernel->get_kr();
+                const size_t sr = kernel->get_sr();
+
+                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
+                any_slot = true;
+            }
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+                if (!kernel || !kernels->rhs_info.packed_size_ex) {
+                    return false;
+                }
+                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
+            }
+
+            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            cursor += k * n * sizeof(float);
+            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            cursor += n * sizeof(float);
+
+            if (!any_slot) {
+                return false;
+            }
+
+            size = cursor;
+            return true;
         }
 
-        return true;
+        return false;
     }
 
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-                return compute_forward_q4_0(params, dst);
-            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
-                return compute_forward_q8_0(params, dst);
+            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
+                return compute_forward_qx(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_fp16(params, dst);
             }
@@ -331,204 +702,457 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
-    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+    bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
 
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        if (!kernels) {
-            return false;
-        }
-
-        bool is_gemv = src1->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
-
-        GGML_ASSERT(kernel);
-        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
-            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
-            return false;
-        }
+        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+        const bool has_header = kleidiai_is_weight_header_valid(header);
+        const bool is_gemv = src1->ne[1] == 1;
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
 
-        const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
+        auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
+            if (slot_index < 0 || slot_index >= slot_total) {
+                return nullptr;
+            }
+            if (has_header) {
+                if (slot_index < header->slot_count) {
+                    size_out = static_cast<size_t>(header->sizes[slot_index]);
+                    return kleidiai_weight_slot_ptr(header, slot_index);
+                }
+                return nullptr;
+            }
+            if (slot_index == 0) {
+                size_out = ggml_nbytes(src0);
+                return static_cast<const uint8_t *>(src0->data);
+            }
+            return nullptr;
+        };
+
+        struct runtime_slot {
+            int slot_index;
+            ggml_kleidiai_kernels * kernels;
+            kernel_info * kernel;
+            lhs_packing_info * lhs_info;
+            size_t mr;
+            size_t nr;
+            size_t kr;
+            size_t sr;
+            size_t n_step;
+            size_t lhs_packed_size;
+            size_t lhs_offset;
+            size_t n_offset;
+            size_t n_cols;
+            int assigned_threads;
+            int thread_begin;
+            int thread_end;
+            const uint8_t * rhs_base;
+        };
+
+        std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
+        int runtime_count = 0;
+
+        for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+            ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+            kernel_info * kinfo      = is_gemv ? &kernels->gemv : &kernels->gemm;
+            lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
+                continue;
+            }
 
-        const size_t k = ne00;
-        const size_t m = ne11;
-        const size_t n = ne01;
+            size_t rhs_size = 0;
+            const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
+            if (!rhs_ptr || rhs_size == 0) {
+                continue;
+            }
 
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
+            runtime[runtime_count] = {
+                slot,
+                kernels,
+                kinfo,
+                linfo,
+                kinfo->get_mr(),
+                kinfo->get_nr(),
+                kinfo->get_kr(),
+                kinfo->get_sr(),
+                kinfo->get_n_step(),
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                rhs_ptr
+            };
+            ++runtime_count;
+        }
 
-        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-        uint8_t * lhs_packed       = (uint8_t*)params->wdata;
-        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+        if (runtime_count == 0) {
+            ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
+            if (!fallback) {
+                return false;
+            }
+            kernel_info * kinfo      = is_gemv ? &fallback->gemv : &fallback->gemm;
+            lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
+            rhs_packing_info * rinfo = &fallback->rhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
+                !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
+                return false;
+            }
+            kernel_chain[0] = fallback;
+            runtime[0] = {
+                0,
+                fallback,
+                kinfo,
+                linfo,
+                kinfo->get_mr(),
+                kinfo->get_nr(),
+                kinfo->get_kr(),
+                kinfo->get_sr(),
+                kinfo->get_n_step(),
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                nullptr
+            };
+            size_t rhs_size_fallback = 0;
+            const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
+            if (!rhs_base) {
+                rhs_base = static_cast<const uint8_t *>(src0->data);
+            }
+            runtime[0].rhs_base = rhs_base;
+            runtime_count = 1;
+        }
 
-        const size_t n_step = kernel->get_n_step();
-        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-        const size_t n_start = ith * num_n_per_thread;
+        const int nth_total = params->nth > 0 ? params->nth : 1;
+        const int ith_total = params->ith;
 
-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
+        int sme_slot = -1;
+        for (int i = 0; i < runtime_count; ++i) {
+            if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+                sme_slot = i;
+                break;
             }
         }
 
-        // Calculate number of columns to be processed per thread
-        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-        const size_t m_start = ith * num_m_per_thread;
-        size_t m_to_process = num_m_per_thread;
-        if ((m_start + m_to_process) > m) {
-            m_to_process = m - m_start;
+        const int sme_cap_limit = ctx.sme_thread_cap;
+        const bool use_hybrid = sme_cap_limit > 0 &&
+                                 runtime_count > 1 &&
+                                 nth_total > sme_cap_limit;
+        // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
+        // If rows are small or average columns per thread are small, keep single-slot.
+        size_t min_cols_per_thread = 0;
+        if (runtime_count > 0 && nth_total > 0) {
+            min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
         }
+        const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
 
-        if (m_start < m) {
-            // Transform LHS
-            const size_t src_stride        = src1->nb[1];
-            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
-            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
-
-            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
-            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
-        }
+        const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
 
-        ggml_barrier(params->threadpool);
+        if (!hybrid_enabled) {
+            int chosen_slot = 0;
+            if (too_small_for_hybrid && sme_slot != -1) {
+                chosen_slot = sme_slot;
+            } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
+                chosen_slot = 1;
+            }
+            if (chosen_slot != 0 && chosen_slot < runtime_count) {
+                runtime[0] = runtime[chosen_slot];
+            }
+            runtime_count = runtime_count > 0 ? 1 : 0;
 
-        // Perform the operation
-        const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
-        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
-        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-        const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
-        float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+            // Recompute SME slot based on the collapsed runtime[0]
+            sme_slot = -1;
+            if (runtime_count > 0 &&
+                (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+                sme_slot = 0;
+            }
+        }
 
-        if (n_to_process > 0) {
-            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                               sizeof(float), -FLT_MAX, FLT_MAX);
+        int sme_cap = kleidiai_sme_thread_cap();
+        if (sme_cap < 0) {
+            sme_cap = nth_total;
         }
+        sme_cap = std::min(sme_cap, nth_total);
 
-        return true;
-    }
+        int threads_remaining = nth_total;
+        if (sme_slot != -1) {
+            int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
+            runtime[sme_slot].assigned_threads = sme_threads;
+            threads_remaining -= sme_threads;
+        }
 
-    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
+        int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+        int fallback_count = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            if (i == sme_slot) {
+                continue;
+            }
+            fallback_indices[fallback_count++] = i;
+        }
 
-        const ggml_tensor * src0 = dst->src[0];
-        const ggml_tensor * src1 = dst->src[1];
+        for (int fi = 0; fi < fallback_count; ++fi) {
+            if (threads_remaining <= 0) {
+                break;
+            }
+            const int slot_index = fallback_indices[fi];
+            const int slots_left = fallback_count - fi;
+            int share = (threads_remaining + slots_left - 1) / slots_left;
+            share     = std::min(share, threads_remaining);
+            runtime[slot_index].assigned_threads = share;
+            threads_remaining -= share;
+        }
 
-        GGML_TENSOR_BINARY_OP_LOCALS
+        if (threads_remaining > 0) {
+            const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
+            runtime[fallback_slot].assigned_threads += threads_remaining;
+            threads_remaining = 0;
+        }
 
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        if (!kernels) {
-            return false;
+        int thread_cursor = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            runtime[i].thread_begin = thread_cursor;
+            thread_cursor += runtime[i].assigned_threads;
+            runtime[i].thread_end = thread_cursor;
         }
 
-        bool is_gemv = src1->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+        if (thread_cursor < nth_total && runtime_count > 0) {
+            runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
+            runtime[runtime_count - 1].thread_end = nth_total;
+        }
 
-        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
-            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+        int local_slot = -1;
+        int local_ith  = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
+                local_slot = i;
+                local_ith  = ith_total - runtime[i].thread_begin;
+                break;
+            }
+        }
+        if (local_slot == -1) {
             return false;
         }
 
-        const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
-
         const size_t k = ne00;
         const size_t m = ne11;
         const size_t n = ne01;
 
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
-
-        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
-        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+        size_t cursor = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
+            const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                              slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+            runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
+            cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            runtime[i].lhs_offset = cursor;
+            cursor += runtime[i].lhs_packed_size;
+        }
 
-        const size_t n_step = kernel->get_n_step();
-        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-        const size_t n_start = ith * num_n_per_thread;
+        GGML_ASSERT(cursor <= params->wsize);
+        uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
 
-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
+        size_t assigned_cols = 0;
+        uint64_t weighted_total = 0;
+        if (runtime_count > 1 && sme_slot != -1) {
+            for (int i = 0; i < runtime_count; ++i) {
+                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+                weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
             }
         }
+        for (int i = 0; i < runtime_count; ++i) {
+            runtime[i].n_offset = assigned_cols;
+            if (runtime[i].assigned_threads == 0) {
+                runtime[i].n_cols = 0;
+                continue;
+            }
+            const size_t remaining_cols = n - assigned_cols;
+            if (remaining_cols == 0) {
+                runtime[i].n_cols = 0;
+                continue;
+            }
+            const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
+            size_t target      = 0;
+            if (weighted_total > 0) {
+                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
+            } else {
+                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
+            }
+            target             = std::min(target, remaining_cols);
+            size_t aligned     = round_down(target, step);
+            if (aligned == 0 && remaining_cols >= step) {
+                aligned = step;
+            }
+            runtime[i].n_cols = aligned;
+            assigned_cols += aligned;
+        }
 
-        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-        const size_t m_start = ith * num_m_per_thread;
-        size_t m_to_process = num_m_per_thread;
-        if ((m_start + m_to_process) > m) {
-            m_to_process = m - m_start;
+        if (assigned_cols < n) {
+            for (int i = runtime_count - 1; i >= 0; --i) {
+                if (runtime[i].assigned_threads > 0) {
+                    runtime[i].n_cols += n - assigned_cols;
+                    break;
+                }
+            }
         }
+        const size_t dst_stride = dst->nb[1];
 
-        if (m_start < m) {
-            const size_t src_stride        = src1->nb[1];
-            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
-            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+        for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
+            const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
 
-            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
-        }
+            if (runtime[local_slot].assigned_threads > 0) {
+                runtime_slot & slot = runtime[local_slot];
+                const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                 slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
+                int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
+                max_threads = std::max<int64_t>(1, max_threads);
+                const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
 
-        ggml_barrier(params->threadpool);
+                if (local_ith < use_threads) {
+                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
+                    const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
 
-        const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
-        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
-        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
-        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+                    const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
+                    const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    const size_t base_packed_off  = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                    const size_t next_block_off   = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                    const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
+
+                    int64_t remaining = m_count;
+                    int64_t cur       = m_start;
+
+                    uint8_t * lhs_packed = scratch + slot.lhs_offset;
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail        = (int64_t)m - row_in_group;
+                        const int64_t take         = std::min(avail, remaining);
+
+                        const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
+                        const void * src_ptr = lhs_batch_base + src_off;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr       = lhs_packed + dst_off;
+
+                        slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
+
+                        cur       += take;
+                        remaining -= take;
+                    }
+                }
+            }
+
+            ggml_barrier(params->threadpool);
 
-        if (n_to_process > 0) {
-            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                                  sizeof(float), -FLT_MAX, FLT_MAX);
+            runtime_slot & slot = runtime[local_slot];
+            if (slot.n_cols > 0 && slot.assigned_threads > 0) {
+                int64_t active_threads = slot.assigned_threads;
+                const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
+                if (max_threads > 0) {
+                    active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
+                }
+                active_threads = std::max<int64_t>(1, active_threads);
+
+                if (local_ith < active_threads) {
+                    const size_t step = slot.n_step ? slot.n_step : 1;
+                    const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
+                    const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
+                    const size_t local_start = (size_t)local_ith * chunk0;
+                    const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
+
+                    if (cols > 0) {
+                        const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                        const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                         slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                        const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                          slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                        const size_t global_start = slot.n_offset + local_start;
+                        const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                        const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
+                        const size_t dst_offset        = slot.kernel->get_dst_offset(0, global_start, dst_stride);
+
+                        const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
+                        const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
+                        float * dst_ptr         = reinterpret_cast<float *>(dst_batch_base + dst_offset);
+
+                        slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
+                                                   lhs_ptr,
+                                                   rhs_ptr,
+                                                   dst_ptr,
+                                                   dst_stride,
+                                                   sizeof(float),
+                                                   -FLT_MAX,
+                                                   FLT_MAX);
+                    }
+                }
+            }
+
+            if (batch_idx != ne12 - 1) {
+                ggml_barrier(params->threadpool);
+            }
         }
 
         return true;
     }
 
     bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
+        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+        const bool has_header = kleidiai_is_weight_header_valid(header);
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
+        const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                        : kleidiai_collect_q4_chain(kernel_chain);
+
         ggml_kleidiai_kernels * kernels = nullptr;
-        size_t block_len = 0;
-        size_t num_bytes_multiplier = 0;
+        const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
 
-        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-            if (!ctx.kernels_q4) {
-                return false;
+        if (has_header && chain_count > 0) {
+            int select_slot = 0;
+            if (select_slot >= header->slot_count) {
+                select_slot = header->slot_count - 1;
             }
-            kernels = ctx.kernels_q4;
-            block_len = QK4_0;
-            num_bytes_multiplier = sizeof(uint16_t);
-        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
-            if (!ctx.kernels_q8) {
-                return false;
+            if (select_slot >= 0 && select_slot < chain_count) {
+                kernels = kernel_chain[select_slot];
+                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
+                if (slot_ptr) {
+                    packed_base = slot_ptr;
+                }
             }
-            kernels = ctx.kernels_q8;
-            block_len = QK8_0;
-            num_bytes_multiplier = sizeof(float);
-        } else {
+        }
+
+        if (!kernels && chain_count > 0) {
+            kernels = kernel_chain[0];
+            if (has_header) {
+                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
+                if (slot_ptr) {
+                    packed_base = slot_ptr;
+                }
+            }
+        }
+
+        if (!kernels) {
             return false;
         }
 
@@ -541,6 +1165,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t nc     = ne00;
         const int64_t nr     = ggml_nelements(src1);
 
+        const ggml_type rhs_type = kernels->rhs_type;
+        size_t block_len = 0;
+        size_t num_bytes_multiplier = 0;
+        if (rhs_type == GGML_TYPE_Q4_0) {
+            block_len = QK4_0;
+            num_bytes_multiplier = sizeof(uint16_t);
+        } else if (rhs_type == GGML_TYPE_Q8_0) {
+            block_len = QK8_0;
+            num_bytes_multiplier = sizeof(float);
+        } else {
+            return false;
+        }
+
         const size_t block_rows = kernel->get_nr();
         const size_t kr         = kernel->get_kr();
 
@@ -559,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
 
             float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
+            rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
         }
 
         return true;
@@ -567,36 +1204,39 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
 
-        if (tensor->type == GGML_TYPE_Q4_0) {
-            if (!ctx.kernels_q4) {
-                return -1;
-            }
-            size_t nr = ctx.kernels_q4->gemm.get_nr();
-            size_t kr = ctx.kernels_q4->gemm.get_kr();
-            size_t sr = ctx.kernels_q4->gemm.get_sr();
+        kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
+        if (!header) {
+            return -1;
+        }
 
-            struct kai_rhs_pack_qs4cxs1s0_param params;
-            params.lhs_zero_point = 1;
-            params.rhs_zero_point = 8;
-            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
-                                                  static_cast<const uint8_t *>(data),
-                                                  nullptr, nullptr, tensor->data, 0, &params);
-            GGML_UNUSED(data_size);
-            return 0;
-        } else if (tensor->type == GGML_TYPE_Q8_0) {
-            if (!ctx.kernels_q8) {
-                return -1;
-            }
+        header->magic      = GGML_KLEIDIAI_PACK_MAGIC;
+        header->version    = GGML_KLEIDIAI_PACK_VERSION;
+        header->slot_count = 0;
+
+        uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
+        size_t cursor = sizeof(kleidiai_weight_header);
+        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+        const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                       : kleidiai_collect_q4_chain(kernel_chain);
+        const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+        std::vector<int8_t> qdata;
+        std::vector<float>  scales;
+
+        if (want_q8 && slot_total > 0) {
+            qdata.resize(n * k, 0);
+            scales.resize(n, 0.0f);
 
             const size_t row_stride = tensor->nb[1];
             const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;
 
-            std::vector<int8_t> qdata(n * k, 0);
-            std::vector<float> scales(n, 0.0f);
-
             for (size_t row = 0; row < n; ++row) {
                 const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
                     static_cast<const uint8_t *>(data) + row * row_stride);
@@ -610,7 +1250,7 @@ public:
                         if (linear_idx >= k) {
                             break;
                         }
-                        const float value = d * blk.qs[l];
+                        const float value = d * static_cast<float>(blk.qs[l]);
                         max_abs = std::max(max_abs, std::fabs(value));
                     }
                 }
@@ -627,31 +1267,73 @@ public:
                         if (linear_idx >= k) {
                             break;
                         }
-                        const float value = d * blk.qs[l];
+                        const float value = d * static_cast<float>(blk.qs[l]);
                         int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
                         q = std::clamp(q, -127, 127);
                         qdata[row * k + linear_idx] = static_cast<int8_t>(q);
                     }
                 }
             }
+        }
+
+        for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+            if (!allow_fallback && slot > 0) {
+                break;
+            }
+            ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+            kernel_info * kernel = &kernels->gemm;
+            rhs_packing_info * rhs_info = &kernels->rhs_info;
+            if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
+                continue;
+            }
+
+            const size_t nr = kernel->get_nr();
+            const size_t kr = kernel->get_kr();
+            const size_t sr = kernel->get_sr();
+            const ggml_type rhs_type = kernels->rhs_type;
+            const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
+                                     rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
+            if (block_len == 0) {
+                continue;
+            }
 
-            size_t nr = ctx.kernels_q8->gemm.get_nr();
-            size_t kr = ctx.kernels_q8->gemm.get_kr();
-            size_t sr = ctx.kernels_q8->gemm.get_sr();
+            const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
+            const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+            uint8_t * dst_ptr = base_ptr + aligned_cursor;
+
+            if (rhs_type == GGML_TYPE_Q4_0) {
+                struct kai_rhs_pack_qs4cxs1s0_param params;
+                params.lhs_zero_point = 1;
+                params.rhs_zero_point = 8;
+                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                       static_cast<const uint8_t *>(data), nullptr, nullptr,
+                                       dst_ptr, 0, &params);
+            } else if (rhs_type == GGML_TYPE_Q8_0) {
+                struct kai_rhs_pack_qsi8cx_params params;
+                params.lhs_zero_point = 1;
+                params.scale_multiplier = 1.0f;
+                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                       qdata.data(), nullptr, scales.data(),
+                                       dst_ptr, 0, &params);
+            } else {
+                continue;
+            }
+
+            header->offsets[header->slot_count] = aligned_cursor;
+            header->sizes[header->slot_count]   = packed_size;
+            ++header->slot_count;
 
-            struct kai_rhs_pack_qsi8cx_params params;
-            params.lhs_zero_point = 1;
-            params.scale_multiplier = 1.0f;
+            cursor = aligned_cursor + packed_size;
+        }
 
-            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
-                                                  qdata.data(), nullptr, scales.data(),
-                                                  tensor->data, 0, &params);
-            GGML_UNUSED(data_size);
-            return 0;
+        if (header->slot_count == 0) {
+            header->magic   = 0;
+            header->version = 0;
+            memcpy(tensor->data, data, data_size);
         }
 
-        GGML_UNUSED(data_size);
-        return -1;
+        return 0;
     }
 };
 
@@ -681,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
 }
 
 static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "CPU_KLEIDIAI";
-
     GGML_UNUSED(buft);
+    return "CPU_KLEIDIAI";
 }
 
 static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -702,49 +1383,78 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return TENSOR_ALIGNMENT;
-
     GGML_UNUSED(buft);
+    return TENSOR_ALIGNMENT;
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     GGML_UNUSED(buft);
 
+    if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
+        return ggml_nbytes(tensor);
+    }
+
     const size_t n = tensor->ne[1];
     const size_t k = tensor->ne[0];
 
-    ggml_kleidiai_kernels * kernels = nullptr;
-    size_t block_len = 0;
-
-    if (tensor->type == GGML_TYPE_Q4_0) {
-        GGML_ASSERT(ctx.kernels_q4);
-        kernels = ctx.kernels_q4;
-        block_len = QK4_0;
-    } else if (tensor->type == GGML_TYPE_Q8_0) {
-        GGML_ASSERT(ctx.kernels_q8);
-        kernels = ctx.kernels_q8;
-        block_len = QK8_0;
-    } else {
-        return 0;
+    size_t cursor = sizeof(kleidiai_weight_header);
+    cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+    const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+    const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                   : kleidiai_collect_q4_chain(kernel_chain);
+    const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+    size_t slot_count = 0;
+    for (int slot = 0; slot < slot_total; ++slot) {
+        if (!allow_fallback && slot > 0) {
+            break;
+        }
+        ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+        if (!kernels) {
+            continue;
+        }
+        kernel_info * kernel = &kernels->gemm;
+        rhs_packing_info * rhs_info = &kernels->rhs_info;
+        if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
+            continue;
+        }
+
+        const ggml_type rhs_type = kernels->rhs_type;
+        const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                 rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+        if (block_len == 0) {
+            continue;
+        }
+
+        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+        cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
+        ++slot_count;
     }
 
-    const size_t nr = kernels->gemm.get_nr();
-    const size_t kr = kernels->gemm.get_kr();
-    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
-    const size_t raw     = ggml_nbytes(tensor);
+    if (slot_count == 0) {
+        return ggml_nbytes(tensor);
+    }
 
-    return packed > raw ? packed : raw;
+    return std::max(cursor, ggml_nbytes(tensor));
 }
 
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
-            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
+            slot_total > 0) {
+            if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
                 return false;
             }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
@@ -762,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
-            }
-            else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
-                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
-                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
-                    return nullptr;
+            } else {
+                std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+                const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
+                const bool has_kernel = slot_total > 0;
+                if (has_kernel && op->src[1]->ne[1] > 1) {
+                    if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                        (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                        return nullptr;
+                    }
+                    return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
                 }
-
-                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
             }
         }
         return nullptr;