]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
kleidiai : support for concurrent sme and neon kernel execution (llama/20070)
authorCharles Xu <redacted>
Tue, 10 Mar 2026 07:25:25 +0000 (08:25 +0100)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
src/ggml-cpu/kleidiai/kernels.cpp
src/ggml-cpu/kleidiai/kleidiai.cpp

index 40f7c0df6505c6c03e19435effe76be39c972610..8c4d7bc925f6ab80ecf7e376ab866504ed2b2c53 100644 (file)
@@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
@@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q4_0,
         /* .op_type            = */ GGML_TYPE_F32,
@@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
             /* .pack_func_ex          = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
         },
-        /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .required_cpu       = */ CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
         /* .rhs_type           = */ GGML_TYPE_Q8_0,
         /* .op_type            = */ GGML_TYPE_F32,
index ad23e73184e5e6255eb6b3f67fc87340e121bff6..9bcc18d442c1fc055318b3e9d8bae3f8aa901d93 100644 (file)
@@ -1,20 +1,31 @@
-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
 // SPDX-License-Identifier: MIT
 //
 #include <arm_neon.h>
 #include <assert.h>
+#include <stdio.h>
 #include <atomic>
 #include <cfloat>
-#include <cmath>
 #include <algorithm>
+#include <cmath>
 #include <stdexcept>
 #include <stdint.h>
 #include <string.h>
 #include <string>
 #include <vector>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <fstream>
+#include <set>
+#include <iostream>
+#include <climits>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
 #elif defined(__APPLE__)
 #include <string_view>
 #include <sys/sysctl.h>
 #define GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
 
+static constexpr int      GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
+static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC       = 0x4b4c4149; // "KLAI"
+static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION     = 1;
+static constexpr size_t   GGML_KLEIDIAI_PACK_ALIGN       = 64;
+
 struct ggml_kleidiai_context {
     cpu_feature features;
     ggml_kleidiai_kernels * kernels_q4;
     ggml_kleidiai_kernels * kernels_q8;
-} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
+    int sme_thread_cap; // <= 0 means “SME disabled/unknown”;
+    int thread_hint;    // <= 0 means “no hint”
+} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
 
 static const char* cpu_feature_to_string(cpu_feature f) {
     if (f == CPU_FEATURE_NONE) {
@@ -63,41 +81,335 @@ static const char* cpu_feature_to_string(cpu_feature f) {
     }
 }
 
-static void init_kleidiai_context(void) {
+static size_t detect_num_smcus() {
+    if (!ggml_cpu_has_sme()) {
+        return 0;
+    }
+
+#if defined(__linux__) && defined(__aarch64__)
+    // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
+    size_t num_private = 0;
+    std::set<uint32_t> shared_ids;
+
+    for (size_t cpu = 0;; ++cpu) {
+        const std::string path =
+            "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
+            "/regs/identification/smidr_el1";
+
+        std::ifstream file(path);
+        if (!file.is_open()) {
+            break;
+        }
+
+        uint64_t smidr = 0;
+        if (!(file >> std::hex >> smidr)) {
+            continue;
+        }
+
+        // Arm ARM: SMIDR_EL1
+        const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
+        // Build an "affinity-like" identifier for shared SMCUs.
+        // Keep the original packing logic, but isolate it here.
+        const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
+
+        switch (sh) {
+            case 0b10: // private SMCU
+                ++num_private;
+                break;
+            case 0b11: // shared SMCU
+                shared_ids.emplace(id);
+                break;
+            case 0b00:
+                // Ambiguous / implementation-defined. Be conservative:
+                // treat id==0 as private, otherwise as shared.
+                if (id == 0) ++num_private;
+                else shared_ids.emplace(id);
+                break;
+            default:
+                break;
+        }
+    }
+
+    return num_private + shared_ids.size();
+
+#elif defined(__APPLE__) && defined(__aarch64__)
+    // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
+    char chip_name[256] = {};
+    size_t size = sizeof(chip_name);
+
+    if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
+        const std::string brand(chip_name);
+
+        struct ModelSMCU { const char *match; size_t smcus; };
+        static const ModelSMCU table[] = {
+            { "M4 Ultra", 2 },
+            { "M4 Max",   2 },
+            { "M4 Pro",   2 },
+            { "M4",       1 },
+        };
 
+        for (const auto &e : table) {
+            if (brand.find(e.match) != std::string::npos) {
+                return e.smcus;
+            }
+        }
+    }
+    return 1;
+
+#else
+    return 1;
+#endif
+}
+
+static int parse_uint_env(const char *s, const char *name, bool *ok) {
+    if (!s) { *ok = false; return 0; }
+    char *end = nullptr;
+    long v = strtol(s, &end, 10);
+    if (end == s || *end != '\0') {
+        GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
+        *ok = false;
+        return 0;
+    }
+    if (v < 0 || v > INT_MAX) {
+        GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
+        *ok = false;
+        return 0;
+    }
+    *ok = true;
+    return (int)v;
+}
+
+static void init_kleidiai_context(void) {
     ggml_critical_section_start();
     static bool initialized = false;
 
     if (!initialized) {
         initialized = true;
-        const char *env_var = getenv("GGML_KLEIDIAI_SME");
-        int sme_enabled = 0;
+
+        const char *env_sme     = getenv("GGML_KLEIDIAI_SME");
+        const char *env_threads = getenv("GGML_TOTAL_THREADS");
+
+        const bool cpu_has_sme = ggml_cpu_has_sme();
+        size_t detected_smcus = 0;
 
         ctx.features  = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                         (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                         ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
 
-        if (env_var) {
-            sme_enabled = atoi(env_var);
+        if (env_threads) {
+            bool ok = false;
+            int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
+            if (ok && hint > 0) {
+                ctx.thread_hint = hint;
+            }
         }
 
-        if (sme_enabled != 0) {
-            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+        // SME policy:
+        // - If CPU doesn't support SME: SME always off.
+        // - Else:
+        //   - env unset => auto-detect cores; enable if detected > 0.
+        //   - env=0     => force off.
+        //   - env>0     => force N cores (skip detection).
+        int sme_cores = 0;
+        bool sme_env_ok = false;
+        bool sme_env_set = (env_sme != nullptr);
+
+        if (!cpu_has_sme) {
+            if (sme_env_set) {
+                bool ok = false;
+                int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+                if (ok && req > 0) {
+                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
+                }
+            }
+            sme_cores = 0;
+        } else {
+            if (sme_env_set) {
+                bool ok = false;
+                int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+                sme_env_ok = ok;
+
+                if (!ok) {
+                    GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
+                    detected_smcus = detect_num_smcus();
+                    sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+                } else if (v == 0) {
+                    sme_cores = 0;
+                } else {
+                    sme_cores = v;
+                }
+            } else {
+                detected_smcus = detect_num_smcus();
+                sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+            }
+
+            if (!sme_env_set && sme_cores == 0) {
+                GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
+            }
+
+            if (sme_cores > 0) {
+                ctx.features |= CPU_FEATURE_SME;
+            }
         }
+
+        // Kernel selection
         ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
         ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
-#ifndef NDEBUG
-        if (ctx.kernels_q4) {
-            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+
+        if (!ctx.kernels_q4) {
+            GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
+        } else {
+            GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+        }
+
+        if (!ctx.kernels_q8) {
+            GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
+        } else {
+            GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
         }
-        if (ctx.kernels_q8) {
-            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
+
+        ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
+
+        if (ctx.features & CPU_FEATURE_SME) {
+            if (sme_env_set && sme_env_ok && sme_cores > 0) {
+                GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
+            } else {
+                GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
+            }
+        } else {
+            GGML_LOG_INFO("kleidiai: SME disabled\n");
         }
-#endif
     }
+
     ggml_critical_section_end();
 }
 
+static inline int kleidiai_sme_thread_cap() {
+    return ctx.sme_thread_cap;
+}
+
+static inline size_t align_up(size_t value, size_t alignment) {
+    if (alignment == 0) {
+        return value;
+    }
+    const size_t remainder = value % alignment;
+    return remainder == 0 ? value : value + (alignment - remainder);
+}
+
+static inline bool kleidiai_pack_fallback_allowed() {
+    if (ctx.sme_thread_cap <= 0) {
+        return false;
+    }
+    if (ctx.thread_hint <= 0) {
+        return true;
+    }
+    return ctx.thread_hint > ctx.sme_thread_cap;
+}
+
+struct kleidiai_weight_header {
+    uint32_t magic;
+    uint16_t version;
+    uint16_t slot_count;
+    uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+    uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+};
+
+static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
+    return reinterpret_cast<kleidiai_weight_header *>(data);
+}
+
+static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
+    return reinterpret_cast<const kleidiai_weight_header *>(data);
+}
+
+static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
+    if (!header) {
+        return false;
+    }
+    if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
+        return false;
+    }
+    if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
+        return false;
+    }
+    return true;
+}
+
+static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
+    if (!kleidiai_is_weight_header_valid(header)) {
+        return nullptr;
+    }
+    if (slot < 0 || slot >= header->slot_count) {
+        return nullptr;
+    }
+    return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
+    if (!kleidiai_is_weight_header_valid(header)) {
+        return nullptr;
+    }
+    if (slot < 0 || slot >= header->slot_count) {
+        return nullptr;
+    }
+    return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
+    return ctx.kernels_q4;
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
+    return ctx.kernels_q8;
+}
+
+template <typename SelectFallback>
+static int kleidiai_collect_kernel_chain_common(
+        ggml_kleidiai_kernels * primary,
+        cpu_feature features,
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
+        SelectFallback select_fallback) {
+    int count = 0;
+    if (!primary) {
+        return 0;
+    }
+    out[count++] = primary;
+
+    if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+        const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
+        if (fallback_mask != CPU_FEATURE_NONE) {
+            ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
+            if (fallback && fallback != primary &&
+                fallback->lhs_type == primary->lhs_type &&
+                fallback->rhs_type == primary->rhs_type &&
+                fallback->op_type  == primary->op_type) {
+                out[count++] = fallback;
+            }
+        }
+    }
+
+    return count;
+}
+
+static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+    ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
+}
+
+static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
+}
+
+static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+    ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
+    return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+        [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
+}
+
 static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
     return tensor->ne[dim];
@@ -126,49 +438,108 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (op->op != GGML_OP_MUL_MAT) {
             return false;
         }
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
-        if (!kernels) {
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
+        if (slot_count == 0) {
             return false;
         }
-        bool is_gemv = op->src[1]->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
 
-        size_t k = op->src[0]->ne[0];
-        size_t n = op->src[0]->ne[1];
-        size_t m = op->src[1]->ne[1];
-
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
-
-        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
-            if (!lhs_info->packed_size_ex) return false;
-            size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
-        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
-            if (!lhs_info->packed_size_ex) return false;
-            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
-        } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
+        const bool is_gemv = op->src[1]->ne[1] == 1;
+        const size_t k = op->src[0]->ne[0];
+        const size_t n = op->src[0]->ne[1];
+        const size_t m = op->src[1]->ne[1];
+
+        if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
+            const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
+
+            size_t cursor = 0;
+            bool any_slot = false;
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+                kernel_info * kernel        = is_gemv ? &kernels->gemv : &kernels->gemm;
+
+                if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
+                    return false;
+                }
+
+                const size_t mr = kernel->get_mr();
+                const size_t kr = kernel->get_kr();
+                const size_t sr = kernel->get_sr();
+
+                const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
+
+                cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += packed;
+                any_slot = true;
+            }
+
+            if (!any_slot) {
+                return false;
+            }
+
+            size = cursor;
+            return true;
+        }
+
+        if (op->src[0]->type == GGML_TYPE_F16) {
             const int64_t lhs_batch_size0 = op->src[1]->ne[2];
             const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            GGML_ASSERT(rhs_batch_size0 > 0);
             const int64_t r = lhs_batch_size0 / rhs_batch_size0;
-            size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
-                   kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
-                   k * n * sizeof(float) + n * sizeof(float);
-        } else {
-            return false;
+
+            size_t cursor = 0;
+            bool any_slot = false;
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+                kernel_info * kernel        = is_gemv ? &kernels->gemv : &kernels->gemm;
+                if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
+                    return false;
+                }
+
+                const size_t mr = kernel->get_mr();
+                const size_t kr = kernel->get_kr();
+                const size_t sr = kernel->get_sr();
+
+                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
+                any_slot = true;
+            }
+
+            for (int slot = 0; slot < slot_count; ++slot) {
+                ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+                kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+                if (!kernel || !kernels->rhs_info.packed_size_ex) {
+                    return false;
+                }
+                cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+                cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
+            }
+
+            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            cursor += k * n * sizeof(float);
+            cursor  = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            cursor += n * sizeof(float);
+
+            if (!any_slot) {
+                return false;
+            }
+
+            size = cursor;
+            return true;
         }
 
-        return true;
+        return false;
     }
 
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-                return compute_forward_q4_0(params, dst);
-            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
-                return compute_forward_q8_0(params, dst);
+            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
+                return compute_forward_qx(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_fp16(params, dst);
             }
@@ -331,204 +702,457 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
-    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+    bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
 
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        if (!kernels) {
-            return false;
-        }
-
-        bool is_gemv = src1->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
-
-        GGML_ASSERT(kernel);
-        if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
-            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
-            return false;
-        }
+        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+        const bool has_header = kleidiai_is_weight_header_valid(header);
+        const bool is_gemv = src1->ne[1] == 1;
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
 
-        const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
+        auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
+            if (slot_index < 0 || slot_index >= slot_total) {
+                return nullptr;
+            }
+            if (has_header) {
+                if (slot_index < header->slot_count) {
+                    size_out = static_cast<size_t>(header->sizes[slot_index]);
+                    return kleidiai_weight_slot_ptr(header, slot_index);
+                }
+                return nullptr;
+            }
+            if (slot_index == 0) {
+                size_out = ggml_nbytes(src0);
+                return static_cast<const uint8_t *>(src0->data);
+            }
+            return nullptr;
+        };
+
+        struct runtime_slot {
+            int slot_index;
+            ggml_kleidiai_kernels * kernels;
+            kernel_info * kernel;
+            lhs_packing_info * lhs_info;
+            size_t mr;
+            size_t nr;
+            size_t kr;
+            size_t sr;
+            size_t n_step;
+            size_t lhs_packed_size;
+            size_t lhs_offset;
+            size_t n_offset;
+            size_t n_cols;
+            int assigned_threads;
+            int thread_begin;
+            int thread_end;
+            const uint8_t * rhs_base;
+        };
+
+        std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
+        int runtime_count = 0;
+
+        for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+            ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+            kernel_info * kinfo      = is_gemv ? &kernels->gemv : &kernels->gemm;
+            lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
+                continue;
+            }
 
-        const size_t k = ne00;
-        const size_t m = ne11;
-        const size_t n = ne01;
+            size_t rhs_size = 0;
+            const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
+            if (!rhs_ptr || rhs_size == 0) {
+                continue;
+            }
 
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
+            runtime[runtime_count] = {
+                slot,
+                kernels,
+                kinfo,
+                linfo,
+                kinfo->get_mr(),
+                kinfo->get_nr(),
+                kinfo->get_kr(),
+                kinfo->get_sr(),
+                kinfo->get_n_step(),
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                rhs_ptr
+            };
+            ++runtime_count;
+        }
 
-        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-        uint8_t * lhs_packed       = (uint8_t*)params->wdata;
-        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+        if (runtime_count == 0) {
+            ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
+            if (!fallback) {
+                return false;
+            }
+            kernel_info * kinfo      = is_gemv ? &fallback->gemv : &fallback->gemm;
+            lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
+            rhs_packing_info * rinfo = &fallback->rhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
+                !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
+                return false;
+            }
+            kernel_chain[0] = fallback;
+            runtime[0] = {
+                0,
+                fallback,
+                kinfo,
+                linfo,
+                kinfo->get_mr(),
+                kinfo->get_nr(),
+                kinfo->get_kr(),
+                kinfo->get_sr(),
+                kinfo->get_n_step(),
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                nullptr
+            };
+            size_t rhs_size_fallback = 0;
+            const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
+            if (!rhs_base) {
+                rhs_base = static_cast<const uint8_t *>(src0->data);
+            }
+            runtime[0].rhs_base = rhs_base;
+            runtime_count = 1;
+        }
 
-        const size_t n_step = kernel->get_n_step();
-        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-        const size_t n_start = ith * num_n_per_thread;
+        const int nth_total = params->nth > 0 ? params->nth : 1;
+        const int ith_total = params->ith;
 
-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
+        int sme_slot = -1;
+        for (int i = 0; i < runtime_count; ++i) {
+            if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+                sme_slot = i;
+                break;
             }
         }
 
-        // Calculate number of columns to be processed per thread
-        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-        const size_t m_start = ith * num_m_per_thread;
-        size_t m_to_process = num_m_per_thread;
-        if ((m_start + m_to_process) > m) {
-            m_to_process = m - m_start;
+        const int sme_cap_limit = ctx.sme_thread_cap;
+        const bool use_hybrid = sme_cap_limit > 0 &&
+                                 runtime_count > 1 &&
+                                 nth_total > sme_cap_limit;
+        // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
+        // If rows are small or average columns per thread are small, keep single-slot.
+        size_t min_cols_per_thread = 0;
+        if (runtime_count > 0 && nth_total > 0) {
+            min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
         }
+        const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128);
 
-        if (m_start < m) {
-            // Transform LHS
-            const size_t src_stride        = src1->nb[1];
-            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
-            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
-
-            // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
-            lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
-        }
+        const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
 
-        ggml_barrier(params->threadpool);
+        if (!hybrid_enabled) {
+            int chosen_slot = 0;
+            if (too_small_for_hybrid && sme_slot != -1) {
+                chosen_slot = sme_slot;
+            } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
+                chosen_slot = 1;
+            }
+            if (chosen_slot != 0 && chosen_slot < runtime_count) {
+                runtime[0] = runtime[chosen_slot];
+            }
+            runtime_count = runtime_count > 0 ? 1 : 0;
 
-        // Perform the operation
-        const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
-        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
-        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-        const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset);
-        float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+            // Recompute SME slot based on the collapsed runtime[0]
+            sme_slot = -1;
+            if (runtime_count > 0 &&
+                (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+                sme_slot = 0;
+            }
+        }
 
-        if (n_to_process > 0) {
-            kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                               sizeof(float), -FLT_MAX, FLT_MAX);
+        int sme_cap = kleidiai_sme_thread_cap();
+        if (sme_cap < 0) {
+            sme_cap = nth_total;
         }
+        sme_cap = std::min(sme_cap, nth_total);
 
-        return true;
-    }
+        int threads_remaining = nth_total;
+        if (sme_slot != -1) {
+            int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
+            runtime[sme_slot].assigned_threads = sme_threads;
+            threads_remaining -= sme_threads;
+        }
 
-    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
+        int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+        int fallback_count = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            if (i == sme_slot) {
+                continue;
+            }
+            fallback_indices[fallback_count++] = i;
+        }
 
-        const ggml_tensor * src0 = dst->src[0];
-        const ggml_tensor * src1 = dst->src[1];
+        for (int fi = 0; fi < fallback_count; ++fi) {
+            if (threads_remaining <= 0) {
+                break;
+            }
+            const int slot_index = fallback_indices[fi];
+            const int slots_left = fallback_count - fi;
+            int share = (threads_remaining + slots_left - 1) / slots_left;
+            share     = std::min(share, threads_remaining);
+            runtime[slot_index].assigned_threads = share;
+            threads_remaining -= share;
+        }
 
-        GGML_TENSOR_BINARY_OP_LOCALS
+        if (threads_remaining > 0) {
+            const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
+            runtime[fallback_slot].assigned_threads += threads_remaining;
+            threads_remaining = 0;
+        }
 
-        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
-        if (!kernels) {
-            return false;
+        int thread_cursor = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            runtime[i].thread_begin = thread_cursor;
+            thread_cursor += runtime[i].assigned_threads;
+            runtime[i].thread_end = thread_cursor;
         }
 
-        bool is_gemv = src1->ne[1] == 1;
-        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
-        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+        if (thread_cursor < nth_total && runtime_count > 0) {
+            runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
+            runtime[runtime_count - 1].thread_end = nth_total;
+        }
 
-        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
-            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+        int local_slot = -1;
+        int local_ith  = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
+                local_slot = i;
+                local_ith  = ith_total - runtime[i].thread_begin;
+                break;
+            }
+        }
+        if (local_slot == -1) {
             return false;
         }
 
-        const int ith = params->ith;
-        const int nth_raw = params->nth;
-        const int nth = nth_raw > 0 ? nth_raw : 1;
-
         const size_t k = ne00;
         const size_t m = ne11;
         const size_t n = ne01;
 
-        size_t mr = kernel->get_mr();
-        size_t kr = kernel->get_kr();
-        size_t sr = kernel->get_sr();
-
-        const uint8_t * lhs        = static_cast<const uint8_t *>(src1->data);
-        uint8_t * lhs_packed       = static_cast<uint8_t *>(params->wdata);
-        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+        size_t cursor = 0;
+        for (int i = 0; i < runtime_count; ++i) {
+            const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
+            const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                              slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+            runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
+            cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+            runtime[i].lhs_offset = cursor;
+            cursor += runtime[i].lhs_packed_size;
+        }
 
-        const size_t n_step = kernel->get_n_step();
-        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
-        const size_t n_start = ith * num_n_per_thread;
+        GGML_ASSERT(cursor <= params->wsize);
+        uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
 
-        size_t n_to_process = 0;
-        if (n_start < n) {
-            n_to_process = num_n_per_thread;
-            if ((n_start + n_to_process) > n) {
-                n_to_process = n - n_start;
+        size_t assigned_cols = 0;
+        uint64_t weighted_total = 0;
+        if (runtime_count > 1 && sme_slot != -1) {
+            for (int i = 0; i < runtime_count; ++i) {
+                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+                weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
             }
         }
+        for (int i = 0; i < runtime_count; ++i) {
+            runtime[i].n_offset = assigned_cols;
+            if (runtime[i].assigned_threads == 0) {
+                runtime[i].n_cols = 0;
+                continue;
+            }
+            const size_t remaining_cols = n - assigned_cols;
+            if (remaining_cols == 0) {
+                runtime[i].n_cols = 0;
+                continue;
+            }
+            const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
+            size_t target      = 0;
+            if (weighted_total > 0) {
+                const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
+            } else {
+                target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
+            }
+            target             = std::min(target, remaining_cols);
+            size_t aligned     = round_down(target, step);
+            if (aligned == 0 && remaining_cols >= step) {
+                aligned = step;
+            }
+            runtime[i].n_cols = aligned;
+            assigned_cols += aligned;
+        }
 
-        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
-        const size_t m_start = ith * num_m_per_thread;
-        size_t m_to_process = num_m_per_thread;
-        if ((m_start + m_to_process) > m) {
-            m_to_process = m - m_start;
+        if (assigned_cols < n) {
+            for (int i = runtime_count - 1; i >= 0; --i) {
+                if (runtime[i].assigned_threads > 0) {
+                    runtime[i].n_cols += n - assigned_cols;
+                    break;
+                }
+            }
         }
+        const size_t dst_stride = dst->nb[1];
 
-        if (m_start < m) {
-            const size_t src_stride        = src1->nb[1];
-            const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
-            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
-            void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset);
+        for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
+            const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
 
-            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
-        }
+            if (runtime[local_slot].assigned_threads > 0) {
+                runtime_slot & slot = runtime[local_slot];
+                const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                 slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
+                int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
+                max_threads = std::max<int64_t>(1, max_threads);
+                const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
 
-        ggml_barrier(params->threadpool);
+                if (local_ith < use_threads) {
+                    const int64_t num_m_per_thread0   = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
+                    const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
 
-        const size_t dst_stride        = dst->nb[1];
-        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
-        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
-        const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride);
-        const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset);
-        const void * lhs_ptr           = static_cast<const void *>(lhs_packed + lhs_packed_offset);
-        float * dst_ptr                = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+                    const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
+                    const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    const size_t base_packed_off  = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                    const size_t next_block_off   = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                    const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
+
+                    int64_t remaining = m_count;
+                    int64_t cur       = m_start;
+
+                    uint8_t * lhs_packed = scratch + slot.lhs_offset;
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail        = (int64_t)m - row_in_group;
+                        const int64_t take         = std::min(avail, remaining);
+
+                        const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
+                        const void * src_ptr = lhs_batch_base + src_off;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr       = lhs_packed + dst_off;
+
+                        slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
+
+                        cur       += take;
+                        remaining -= take;
+                    }
+                }
+            }
+
+            ggml_barrier(params->threadpool);
 
-        if (n_to_process > 0) {
-            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                                  sizeof(float), -FLT_MAX, FLT_MAX);
+            runtime_slot & slot = runtime[local_slot];
+            if (slot.n_cols > 0 && slot.assigned_threads > 0) {
+                int64_t active_threads = slot.assigned_threads;
+                const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
+                if (max_threads > 0) {
+                    active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
+                }
+                active_threads = std::max<int64_t>(1, active_threads);
+
+                if (local_ith < active_threads) {
+                    const size_t step = slot.n_step ? slot.n_step : 1;
+                    const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
+                    const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
+                    const size_t local_start = (size_t)local_ith * chunk0;
+                    const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
+
+                    if (cols > 0) {
+                        const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                        const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                         slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                        const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                                          slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0;
+                        const size_t global_start = slot.n_offset + local_start;
+                        const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+                        const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
+                        const size_t dst_offset        = slot.kernel->get_dst_offset(0, global_start, dst_stride);
+
+                        const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
+                        const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
+                        float * dst_ptr         = reinterpret_cast<float *>(dst_batch_base + dst_offset);
+
+                        slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
+                                                   lhs_ptr,
+                                                   rhs_ptr,
+                                                   dst_ptr,
+                                                   dst_stride,
+                                                   sizeof(float),
+                                                   -FLT_MAX,
+                                                   FLT_MAX);
+                    }
+                }
+            }
+
+            if (batch_idx != ne12 - 1) {
+                ggml_barrier(params->threadpool);
+            }
         }
 
         return true;
     }
 
     bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
+        const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+        const bool has_header = kleidiai_is_weight_header_valid(header);
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
+        const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                        : kleidiai_collect_q4_chain(kernel_chain);
+
         ggml_kleidiai_kernels * kernels = nullptr;
-        size_t block_len = 0;
-        size_t num_bytes_multiplier = 0;
+        const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
 
-        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
-            if (!ctx.kernels_q4) {
-                return false;
+        if (has_header && chain_count > 0) {
+            int select_slot = 0;
+            if (select_slot >= header->slot_count) {
+                select_slot = header->slot_count - 1;
             }
-            kernels = ctx.kernels_q4;
-            block_len = QK4_0;
-            num_bytes_multiplier = sizeof(uint16_t);
-        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
-            if (!ctx.kernels_q8) {
-                return false;
+            if (select_slot >= 0 && select_slot < chain_count) {
+                kernels = kernel_chain[select_slot];
+                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
+                if (slot_ptr) {
+                    packed_base = slot_ptr;
+                }
             }
-            kernels = ctx.kernels_q8;
-            block_len = QK8_0;
-            num_bytes_multiplier = sizeof(float);
-        } else {
+        }
+
+        if (!kernels && chain_count > 0) {
+            kernels = kernel_chain[0];
+            if (has_header) {
+                const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
+                if (slot_ptr) {
+                    packed_base = slot_ptr;
+                }
+            }
+        }
+
+        if (!kernels) {
             return false;
         }
 
@@ -541,6 +1165,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t nc     = ne00;
         const int64_t nr     = ggml_nelements(src1);
 
+        const ggml_type rhs_type = kernels->rhs_type;
+        size_t block_len = 0;
+        size_t num_bytes_multiplier = 0;
+        if (rhs_type == GGML_TYPE_Q4_0) {
+            block_len = QK4_0;
+            num_bytes_multiplier = sizeof(uint16_t);
+        } else if (rhs_type == GGML_TYPE_Q8_0) {
+            block_len = QK8_0;
+            num_bytes_multiplier = sizeof(float);
+        } else {
+            return false;
+        }
+
         const size_t block_rows = kernel->get_nr();
         const size_t kr         = kernel->get_kr();
 
@@ -559,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
 
             float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
+            rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
         }
 
         return true;
@@ -567,36 +1204,39 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
 
-        if (tensor->type == GGML_TYPE_Q4_0) {
-            if (!ctx.kernels_q4) {
-                return -1;
-            }
-            size_t nr = ctx.kernels_q4->gemm.get_nr();
-            size_t kr = ctx.kernels_q4->gemm.get_kr();
-            size_t sr = ctx.kernels_q4->gemm.get_sr();
+        kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
+        if (!header) {
+            return -1;
+        }
 
-            struct kai_rhs_pack_qs4cxs1s0_param params;
-            params.lhs_zero_point = 1;
-            params.rhs_zero_point = 8;
-            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
-                                                  static_cast<const uint8_t *>(data),
-                                                  nullptr, nullptr, tensor->data, 0, &params);
-            GGML_UNUSED(data_size);
-            return 0;
-        } else if (tensor->type == GGML_TYPE_Q8_0) {
-            if (!ctx.kernels_q8) {
-                return -1;
-            }
+        header->magic      = GGML_KLEIDIAI_PACK_MAGIC;
+        header->version    = GGML_KLEIDIAI_PACK_VERSION;
+        header->slot_count = 0;
+
+        uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
+        size_t cursor = sizeof(kleidiai_weight_header);
+        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+        const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                       : kleidiai_collect_q4_chain(kernel_chain);
+        const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+        std::vector<int8_t> qdata;
+        std::vector<float>  scales;
+
+        if (want_q8 && slot_total > 0) {
+            qdata.resize(n * k, 0);
+            scales.resize(n, 0.0f);
 
             const size_t row_stride = tensor->nb[1];
             const size_t k_blocks   = (k + QK8_0 - 1) / QK8_0;
 
-            std::vector<int8_t> qdata(n * k, 0);
-            std::vector<float> scales(n, 0.0f);
-
             for (size_t row = 0; row < n; ++row) {
                 const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
                     static_cast<const uint8_t *>(data) + row * row_stride);
@@ -610,7 +1250,7 @@ public:
                         if (linear_idx >= k) {
                             break;
                         }
-                        const float value = d * blk.qs[l];
+                        const float value = d * static_cast<float>(blk.qs[l]);
                         max_abs = std::max(max_abs, std::fabs(value));
                     }
                 }
@@ -627,31 +1267,73 @@ public:
                         if (linear_idx >= k) {
                             break;
                         }
-                        const float value = d * blk.qs[l];
+                        const float value = d * static_cast<float>(blk.qs[l]);
                         int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
                         q = std::clamp(q, -127, 127);
                         qdata[row * k + linear_idx] = static_cast<int8_t>(q);
                     }
                 }
             }
+        }
+
+        for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+            if (!allow_fallback && slot > 0) {
+                break;
+            }
+            ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+            kernel_info * kernel = &kernels->gemm;
+            rhs_packing_info * rhs_info = &kernels->rhs_info;
+            if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
+                continue;
+            }
+
+            const size_t nr = kernel->get_nr();
+            const size_t kr = kernel->get_kr();
+            const size_t sr = kernel->get_sr();
+            const ggml_type rhs_type = kernels->rhs_type;
+            const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
+                                     rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
+            if (block_len == 0) {
+                continue;
+            }
 
-            size_t nr = ctx.kernels_q8->gemm.get_nr();
-            size_t kr = ctx.kernels_q8->gemm.get_kr();
-            size_t sr = ctx.kernels_q8->gemm.get_sr();
+            const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
+            const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+            uint8_t * dst_ptr = base_ptr + aligned_cursor;
+
+            if (rhs_type == GGML_TYPE_Q4_0) {
+                struct kai_rhs_pack_qs4cxs1s0_param params;
+                params.lhs_zero_point = 1;
+                params.rhs_zero_point = 8;
+                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                       static_cast<const uint8_t *>(data), nullptr, nullptr,
+                                       dst_ptr, 0, &params);
+            } else if (rhs_type == GGML_TYPE_Q8_0) {
+                struct kai_rhs_pack_qsi8cx_params params;
+                params.lhs_zero_point = 1;
+                params.scale_multiplier = 1.0f;
+                rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                       qdata.data(), nullptr, scales.data(),
+                                       dst_ptr, 0, &params);
+            } else {
+                continue;
+            }
+
+            header->offsets[header->slot_count] = aligned_cursor;
+            header->sizes[header->slot_count]   = packed_size;
+            ++header->slot_count;
 
-            struct kai_rhs_pack_qsi8cx_params params;
-            params.lhs_zero_point = 1;
-            params.scale_multiplier = 1.0f;
+            cursor = aligned_cursor + packed_size;
+        }
 
-            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
-                                                  qdata.data(), nullptr, scales.data(),
-                                                  tensor->data, 0, &params);
-            GGML_UNUSED(data_size);
-            return 0;
+        if (header->slot_count == 0) {
+            header->magic   = 0;
+            header->version = 0;
+            memcpy(tensor->data, data, data_size);
         }
 
-        GGML_UNUSED(data_size);
-        return -1;
+        return 0;
     }
 };
 
@@ -681,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu
 }
 
 static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
-    return "CPU_KLEIDIAI";
-
     GGML_UNUSED(buft);
+    return "CPU_KLEIDIAI";
 }
 
 static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -702,49 +1383,78 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return TENSOR_ALIGNMENT;
-
     GGML_UNUSED(buft);
+    return TENSOR_ALIGNMENT;
 }
 
 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     GGML_UNUSED(buft);
 
+    if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
+        return ggml_nbytes(tensor);
+    }
+
     const size_t n = tensor->ne[1];
     const size_t k = tensor->ne[0];
 
-    ggml_kleidiai_kernels * kernels = nullptr;
-    size_t block_len = 0;
-
-    if (tensor->type == GGML_TYPE_Q4_0) {
-        GGML_ASSERT(ctx.kernels_q4);
-        kernels = ctx.kernels_q4;
-        block_len = QK4_0;
-    } else if (tensor->type == GGML_TYPE_Q8_0) {
-        GGML_ASSERT(ctx.kernels_q8);
-        kernels = ctx.kernels_q8;
-        block_len = QK8_0;
-    } else {
-        return 0;
+    size_t cursor = sizeof(kleidiai_weight_header);
+    cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+    std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+    const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+    const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+                                   : kleidiai_collect_q4_chain(kernel_chain);
+    const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+    size_t slot_count = 0;
+    for (int slot = 0; slot < slot_total; ++slot) {
+        if (!allow_fallback && slot > 0) {
+            break;
+        }
+        ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+        if (!kernels) {
+            continue;
+        }
+        kernel_info * kernel = &kernels->gemm;
+        rhs_packing_info * rhs_info = &kernels->rhs_info;
+        if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
+            continue;
+        }
+
+        const ggml_type rhs_type = kernels->rhs_type;
+        const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+                                 rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+        if (block_len == 0) {
+            continue;
+        }
+
+        cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+        cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
+        ++slot_count;
     }
 
-    const size_t nr = kernels->gemm.get_nr();
-    const size_t kr = kernels->gemm.get_kr();
-    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
-    const size_t raw     = ggml_nbytes(tensor);
+    if (slot_count == 0) {
+        return ggml_nbytes(tensor);
+    }
 
-    return packed > raw ? packed : raw;
+    return std::max(cursor, ggml_nbytes(tensor));
 }
 
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+        std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+        const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
-            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
+            slot_total > 0) {
+            if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
                 return false;
             }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
@@ -762,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
-            }
-            else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
-                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
-                    (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
-                    return nullptr;
+            } else {
+                std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+                const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
+                const bool has_kernel = slot_total > 0;
+                if (has_kernel && op->src[1]->ne[1] > 1) {
+                    if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                        (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                        return nullptr;
+                    }
+                    return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
                 }
-
-                return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
             }
         }
         return nullptr;