-// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
#include <arm_neon.h>
#include <assert.h>
+#include <stdio.h>
#include <atomic>
#include <cfloat>
-#include <cmath>
#include <algorithm>
+#include <cmath>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#include <vector>
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <fstream>
+#include <set>
+#include <iostream>
+#include <climits>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
#elif defined(__APPLE__)
#include <string_view>
#include <sys/sysctl.h>
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
+static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2;
+static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI"
+static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1;
+static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64;
+
struct ggml_kleidiai_context {
cpu_feature features;
ggml_kleidiai_kernels * kernels_q4;
ggml_kleidiai_kernels * kernels_q8;
-} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
+    int sme_thread_cap; // <= 0 means "SME disabled/unknown"
+    int thread_hint;    // <= 0 means "no hint"
+} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 };
static const char* cpu_feature_to_string(cpu_feature f) {
if (f == CPU_FEATURE_NONE) {
}
}
-static void init_kleidiai_context(void) {
+static size_t detect_num_smcus() {
+ if (!ggml_cpu_has_sme()) {
+ return 0;
+ }
+
+#if defined(__linux__) && defined(__aarch64__)
+ // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs.
+ size_t num_private = 0;
+ std::set<uint32_t> shared_ids;
+
+ for (size_t cpu = 0;; ++cpu) {
+ const std::string path =
+ "/sys/devices/system/cpu/cpu" + std::to_string(cpu) +
+ "/regs/identification/smidr_el1";
+
+ std::ifstream file(path);
+ if (!file.is_open()) {
+ break;
+ }
+
+ uint64_t smidr = 0;
+ if (!(file >> std::hex >> smidr)) {
+ continue;
+ }
+
+        // SMIDR_EL1 (see Arm ARM): bits [14:13] are read as a private/shared SMCU indicator below.
+ const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3);
+ // Build an "affinity-like" identifier for shared SMCUs.
+ // Keep the original packing logic, but isolate it here.
+ const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u));
+
+ switch (sh) {
+ case 0b10: // private SMCU
+ ++num_private;
+ break;
+ case 0b11: // shared SMCU
+ shared_ids.emplace(id);
+ break;
+ case 0b00:
+ // Ambiguous / implementation-defined. Be conservative:
+ // treat id==0 as private, otherwise as shared.
+ if (id == 0) ++num_private;
+ else shared_ids.emplace(id);
+ break;
+ default:
+ break;
+ }
+ }
+
+ return num_private + shared_ids.size();
+
+#elif defined(__APPLE__) && defined(__aarch64__)
+    // Apple Silicon: best-effort lookup table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>.
+ char chip_name[256] = {};
+ size_t size = sizeof(chip_name);
+
+ if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) {
+ const std::string brand(chip_name);
+
+ struct ModelSMCU { const char *match; size_t smcus; };
+ static const ModelSMCU table[] = {
+ { "M4 Ultra", 2 },
+ { "M4 Max", 2 },
+ { "M4 Pro", 2 },
+ { "M4", 1 },
+ };
+ for (const auto &e : table) {
+ if (brand.find(e.match) != std::string::npos) {
+ return e.smcus;
+ }
+ }
+ }
+ return 1;
+
+#else
+ return 1;
+#endif
+}
+
+static int parse_uint_env(const char *s, const char *name, bool *ok) {
+ if (!s) { *ok = false; return 0; }
+ char *end = nullptr;
+ long v = strtol(s, &end, 10);
+ if (end == s || *end != '\0') {
+ GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s);
+ *ok = false;
+ return 0;
+ }
+ if (v < 0 || v > INT_MAX) {
+ GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s);
+ *ok = false;
+ return 0;
+ }
+ *ok = true;
+ return (int)v;
+}
+
+static void init_kleidiai_context(void) {
ggml_critical_section_start();
static bool initialized = false;
if (!initialized) {
initialized = true;
- const char *env_var = getenv("GGML_KLEIDIAI_SME");
- int sme_enabled = 0;
+
+ const char *env_sme = getenv("GGML_KLEIDIAI_SME");
+ const char *env_threads = getenv("GGML_TOTAL_THREADS");
+
+ const bool cpu_has_sme = ggml_cpu_has_sme();
+ size_t detected_smcus = 0;
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
- if (env_var) {
- sme_enabled = atoi(env_var);
+ if (env_threads) {
+ bool ok = false;
+ int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok);
+ if (ok && hint > 0) {
+ ctx.thread_hint = hint;
+ }
}
- if (sme_enabled != 0) {
- ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+ // SME policy:
+ // - If CPU doesn't support SME: SME always off.
+ // - Else:
+ // - env unset => auto-detect cores; enable if detected > 0.
+ // - env=0 => force off.
+ // - env>0 => force N cores (skip detection).
+ int sme_cores = 0;
+ bool sme_env_ok = false;
+ bool sme_env_set = (env_sme != nullptr);
+
+ if (!cpu_has_sme) {
+ if (sme_env_set) {
+ bool ok = false;
+ int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+ if (ok && req > 0) {
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req);
+ }
+ }
+ sme_cores = 0;
+ } else {
+ if (sme_env_set) {
+ bool ok = false;
+ int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok);
+ sme_env_ok = ok;
+
+ if (!ok) {
+ GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n");
+ detected_smcus = detect_num_smcus();
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+ } else if (v == 0) {
+ sme_cores = 0;
+ } else {
+ sme_cores = v;
+ }
+ } else {
+ detected_smcus = detect_num_smcus();
+ sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0;
+ }
+
+ if (!sme_env_set && sme_cores == 0) {
+ GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n");
+ }
+
+ if (sme_cores > 0) {
+ ctx.features |= CPU_FEATURE_SME;
+ }
}
+
+ // Kernel selection
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
-#ifndef NDEBUG
- if (ctx.kernels_q4) {
- GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+
+ if (!ctx.kernels_q4) {
+ GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features);
+ } else {
+ GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+ }
+
+ if (!ctx.kernels_q8) {
+ GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features);
+ } else {
+ GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
}
- if (ctx.kernels_q8) {
- GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
+
+ ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0;
+
+ if (ctx.features & CPU_FEATURE_SME) {
+ if (sme_env_set && sme_env_ok && sme_cores > 0) {
+ GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores);
+ } else {
+ GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores);
+ }
+ } else {
+ GGML_LOG_INFO("kleidiai: SME disabled\n");
}
-#endif
}
+
ggml_critical_section_end();
}
+static inline int kleidiai_sme_thread_cap() {
+ return ctx.sme_thread_cap;
+}
+
+static inline size_t align_up(size_t value, size_t alignment) {
+ if (alignment == 0) {
+ return value;
+ }
+ const size_t remainder = value % alignment;
+ return remainder == 0 ? value : value + (alignment - remainder);
+}
+
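+// A second (non-SME) packed weight copy is only worth storing when SME is active and the requested thread count
+// (GGML_TOTAL_THREADS) is unknown or exceeds the SME thread cap, i.e. when the fallback kernel may actually run.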
+static inline bool kleidiai_pack_fallback_allowed() {
+ if (ctx.sme_thread_cap <= 0) {
+ return false;
+ }
+ if (ctx.thread_hint <= 0) {
+ return true;
+ }
+ return ctx.thread_hint > ctx.sme_thread_cap;
+}
+
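+// Header written at the start of a repacked weight buffer: offsets[]/sizes[] describe each kernel slot's packed RHS
+// copy as byte offsets from the buffer base (see repack()).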
+struct kleidiai_weight_header {
+ uint32_t magic;
+ uint16_t version;
+ uint16_t slot_count;
+ uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+ uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+};
+
+static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) {
+ return reinterpret_cast<kleidiai_weight_header *>(data);
+}
+
+static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) {
+ return reinterpret_cast<const kleidiai_weight_header *>(data);
+}
+
+static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) {
+ if (!header) {
+ return false;
+ }
+ if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) {
+ return false;
+ }
+ if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) {
+ return false;
+ }
+ return true;
+}
+
+static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) {
+ if (!kleidiai_is_weight_header_valid(header)) {
+ return nullptr;
+ }
+ if (slot < 0 || slot >= header->slot_count) {
+ return nullptr;
+ }
+ return reinterpret_cast<uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) {
+ if (!kleidiai_is_weight_header_valid(header)) {
+ return nullptr;
+ }
+ if (slot < 0 || slot >= header->slot_count) {
+ return nullptr;
+ }
+ return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot];
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() {
+ return ctx.kernels_q4;
+}
+
+static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() {
+ return ctx.kernels_q8;
+}
+
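+// Build the kernel chain: slot 0 is the primary kernel; when the primary requires SME, a non-SME fallback with
+// matching lhs/rhs/op types is appended as slot 1. Returns the number of slots written to out.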
+template <typename SelectFallback>
+static int kleidiai_collect_kernel_chain_common(
+ ggml_kleidiai_kernels * primary,
+ cpu_feature features,
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out,
+ SelectFallback select_fallback) {
+ int count = 0;
+ if (!primary) {
+ return 0;
+ }
+ out[count++] = primary;
+
+ if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+ const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME);
+ if (fallback_mask != CPU_FEATURE_NONE) {
+ ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask);
+ if (fallback && fallback != primary &&
+ fallback->lhs_type == primary->lhs_type &&
+ fallback->rhs_type == primary->rhs_type &&
+ fallback->op_type == primary->op_type) {
+ out[count++] = fallback;
+ }
+ }
+ }
+
+ return count;
+}
+
+static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op,
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+ ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op);
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); });
+}
+
+static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4();
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); });
+}
+
+static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) {
+ ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8();
+ return kleidiai_collect_kernel_chain_common(primary, ctx.features, out,
+ [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); });
+}
+
static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS);
return tensor->ne[dim];
if (op->op != GGML_OP_MUL_MAT) {
return false;
}
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
- if (!kernels) {
+
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain);
+ if (slot_count == 0) {
return false;
}
- bool is_gemv = op->src[1]->ne[1] == 1;
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
- size_t k = op->src[0]->ne[0];
- size_t n = op->src[0]->ne[1];
- size_t m = op->src[1]->ne[1];
-
- size_t mr = kernel->get_mr();
- size_t kr = kernel->get_kr();
- size_t sr = kernel->get_sr();
-
- if (kernels->rhs_type == GGML_TYPE_Q4_0) {
- if (!lhs_info->packed_size_ex) return false;
- size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
- } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
- if (!lhs_info->packed_size_ex) return false;
- size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
- } else if (kernels->rhs_type == GGML_TYPE_F16) {
- if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
+ const bool is_gemv = op->src[1]->ne[1] == 1;
+ const size_t k = op->src[0]->ne[0];
+ const size_t n = op->src[0]->ne[1];
+ const size_t m = op->src[1]->ne[1];
+
+ if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) {
+ const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0;
+
+ size_t cursor = 0;
+ bool any_slot = false;
+
+ for (int slot = 0; slot < slot_count; ++slot) {
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernel) {
+ return false;
+ }
+
+ const size_t mr = kernel->get_mr();
+ const size_t kr = kernel->get_kr();
+ const size_t sr = kernel->get_sr();
+
+ const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr);
+
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ cursor += packed;
+ any_slot = true;
+ }
+
+ if (!any_slot) {
+ return false;
+ }
+
+ size = cursor;
+ return true;
+ }
+
+ if (op->src[0]->type == GGML_TYPE_F16) {
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+ GGML_ASSERT(rhs_batch_size0 > 0);
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
- size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
- kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
- k * n * sizeof(float) + n * sizeof(float);
- } else {
- return false;
+
+ size_t cursor = 0;
+ bool any_slot = false;
+
+ for (int slot = 0; slot < slot_count; ++slot) {
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+ lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+ if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) {
+ return false;
+ }
+
+ const size_t mr = kernel->get_mr();
+ const size_t kr = kernel->get_kr();
+ const size_t sr = kernel->get_sr();
+
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr);
+ any_slot = true;
+ }
+
+ for (int slot = 0; slot < slot_count; ++slot) {
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+ kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+ if (!kernel || !kernels->rhs_info.packed_size_ex) {
+ return false;
+ }
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0);
+ }
+
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ cursor += k * n * sizeof(float);
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ cursor += n * sizeof(float);
+
+ if (!any_slot) {
+ return false;
+ }
+
+ size = cursor;
+ return true;
}
- return true;
+ return false;
}
bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
if (dst->op == GGML_OP_MUL_MAT) {
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
- return compute_forward_q4_0(params, dst);
- } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
- return compute_forward_q8_0(params, dst);
+ if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
+ return compute_forward_qx(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_fp16(params, dst);
}
return true;
}
- bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+ bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
- if (!kernels) {
- return false;
- }
-
- bool is_gemv = src1->ne[1] == 1;
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
-
- GGML_ASSERT(kernel);
- if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
- !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
- return false;
- }
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+ const bool has_header = kleidiai_is_weight_header_valid(header);
+ const bool is_gemv = src1->ne[1] == 1;
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain);
- const int ith = params->ith;
- const int nth_raw = params->nth;
- const int nth = nth_raw > 0 ? nth_raw : 1;
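+        // Resolve the packed RHS pointer for a kernel slot: taken from the weight header when present; otherwise
+        // slot 0 falls back to the raw tensor data.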
+ auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * {
+ if (slot_index < 0 || slot_index >= slot_total) {
+ return nullptr;
+ }
+ if (has_header) {
+ if (slot_index < header->slot_count) {
+ size_out = static_cast<size_t>(header->sizes[slot_index]);
+ return kleidiai_weight_slot_ptr(header, slot_index);
+ }
+ return nullptr;
+ }
+ if (slot_index == 0) {
+ size_out = ggml_nbytes(src0);
+ return static_cast<const uint8_t *>(src0->data);
+ }
+ return nullptr;
+ };
+
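+        // Per-slot execution state: kernel handles, packing geometry, the scratch offset of the packed LHS, and the
+        // thread/column ranges assigned to this slot.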
+ struct runtime_slot {
+ int slot_index;
+ ggml_kleidiai_kernels * kernels;
+ kernel_info * kernel;
+ lhs_packing_info * lhs_info;
+ size_t mr;
+ size_t nr;
+ size_t kr;
+ size_t sr;
+ size_t n_step;
+ size_t lhs_packed_size;
+ size_t lhs_offset;
+ size_t n_offset;
+ size_t n_cols;
+ int assigned_threads;
+ int thread_begin;
+ int thread_end;
+ const uint8_t * rhs_base;
+ };
+
+ std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{};
+ int runtime_count = 0;
+
+ for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+ kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm;
+ lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset ||
+                !linfo->get_packed_offset_ex ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) {
+ continue;
+ }
- const size_t k = ne00;
- const size_t m = ne11;
- const size_t n = ne01;
+ size_t rhs_size = 0;
+ const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size);
+ if (!rhs_ptr || rhs_size == 0) {
+ continue;
+ }
- size_t mr = kernel->get_mr();
- size_t kr = kernel->get_kr();
- size_t sr = kernel->get_sr();
+ runtime[runtime_count] = {
+ slot,
+ kernels,
+ kinfo,
+ linfo,
+ kinfo->get_mr(),
+ kinfo->get_nr(),
+ kinfo->get_kr(),
+ kinfo->get_sr(),
+ kinfo->get_n_step(),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ rhs_ptr
+ };
+ ++runtime_count;
+ }
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
- uint8_t * lhs_packed = (uint8_t*)params->wdata;
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+ if (runtime_count == 0) {
+ ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst);
+ if (!fallback) {
+ return false;
+ }
+ kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm;
+ lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info;
+ rhs_packing_info * rinfo = &fallback->rhs_info;
+            if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex ||
+                !linfo->get_offset || !linfo->get_packed_offset_ex ||
+                !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset ||
+                !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) {
+ return false;
+ }
+ kernel_chain[0] = fallback;
+ runtime[0] = {
+ 0,
+ fallback,
+ kinfo,
+ linfo,
+ kinfo->get_mr(),
+ kinfo->get_nr(),
+ kinfo->get_kr(),
+ kinfo->get_sr(),
+ kinfo->get_n_step(),
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ nullptr
+ };
+ size_t rhs_size_fallback = 0;
+ const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback);
+ if (!rhs_base) {
+ rhs_base = static_cast<const uint8_t *>(src0->data);
+ }
+ runtime[0].rhs_base = rhs_base;
+ runtime_count = 1;
+ }
- const size_t n_step = kernel->get_n_step();
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
- const size_t n_start = ith * num_n_per_thread;
+ const int nth_total = params->nth > 0 ? params->nth : 1;
+ const int ith_total = params->ith;
- size_t n_to_process = 0;
- if (n_start < n) {
- n_to_process = num_n_per_thread;
- if ((n_start + n_to_process) > n) {
- n_to_process = n - n_start;
+ int sme_slot = -1;
+ for (int i = 0; i < runtime_count; ++i) {
+ if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+ sme_slot = i;
+ break;
}
}
- // Calculate number of columns to be processed per thread
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
- const size_t m_start = ith * num_m_per_thread;
- size_t m_to_process = num_m_per_thread;
- if ((m_start + m_to_process) > m) {
- m_to_process = m - m_start;
+ const int sme_cap_limit = ctx.sme_thread_cap;
+ const bool use_hybrid = sme_cap_limit > 0 &&
+ runtime_count > 1 &&
+ nth_total > sme_cap_limit;
+ // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates.
+ // If rows are small or average columns per thread are small, keep single-slot.
+    size_t avg_cols_per_thread = 0;
+    if (runtime_count > 0 && nth_total > 0) {
+        avg_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total);
}
+    const bool too_small_for_hybrid = (avg_cols_per_thread < 2) || (ne11 < 128);
- if (m_start < m) {
- // Transform LHS
- const size_t src_stride = src1->nb[1];
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
-
- // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
- lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
- }
+ const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid;
- ggml_barrier(params->threadpool);
+ if (!hybrid_enabled) {
+ int chosen_slot = 0;
+ if (too_small_for_hybrid && sme_slot != -1) {
+ chosen_slot = sme_slot;
+ } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) {
+ chosen_slot = 1;
+ }
+ if (chosen_slot != 0 && chosen_slot < runtime_count) {
+ runtime[0] = runtime[chosen_slot];
+ }
+ runtime_count = runtime_count > 0 ? 1 : 0;
- // Perform the operation
- const size_t dst_stride = dst->nb[1];
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
- const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
- float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+ // Recompute SME slot based on the collapsed runtime[0]
+ sme_slot = -1;
+ if (runtime_count > 0 &&
+ (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+ sme_slot = 0;
+ }
+ }
- if (n_to_process > 0) {
- kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
- sizeof(float), -FLT_MAX, FLT_MAX);
+ int sme_cap = kleidiai_sme_thread_cap();
+ if (sme_cap < 0) {
+ sme_cap = nth_total;
}
+ sme_cap = std::min(sme_cap, nth_total);
- return true;
- }
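+    // Thread split: give the SME slot at most its thread cap and distribute the remaining threads across the
+    // fallback slots.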
+ int threads_remaining = nth_total;
+ if (sme_slot != -1) {
+ int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining);
+ runtime[sme_slot].assigned_threads = sme_threads;
+ threads_remaining -= sme_threads;
+ }
- bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
+ int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS];
+ int fallback_count = 0;
+ for (int i = 0; i < runtime_count; ++i) {
+ if (i == sme_slot) {
+ continue;
+ }
+ fallback_indices[fallback_count++] = i;
+ }
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
+ for (int fi = 0; fi < fallback_count; ++fi) {
+ if (threads_remaining <= 0) {
+ break;
+ }
+ const int slot_index = fallback_indices[fi];
+ const int slots_left = fallback_count - fi;
+ int share = (threads_remaining + slots_left - 1) / slots_left;
+ share = std::min(share, threads_remaining);
+ runtime[slot_index].assigned_threads = share;
+ threads_remaining -= share;
+ }
- GGML_TENSOR_BINARY_OP_LOCALS
+ if (threads_remaining > 0) {
+ const int fallback_slot = (sme_slot != -1) ? sme_slot : 0;
+ runtime[fallback_slot].assigned_threads += threads_remaining;
+ threads_remaining = 0;
+ }
- ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
- if (!kernels) {
- return false;
+ int thread_cursor = 0;
+ for (int i = 0; i < runtime_count; ++i) {
+ runtime[i].thread_begin = thread_cursor;
+ thread_cursor += runtime[i].assigned_threads;
+ runtime[i].thread_end = thread_cursor;
}
- bool is_gemv = src1->ne[1] == 1;
- kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
- lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+ if (thread_cursor < nth_total && runtime_count > 0) {
+ runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor;
+ runtime[runtime_count - 1].thread_end = nth_total;
+ }
- if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
- !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+ int local_slot = -1;
+ int local_ith = 0;
+ for (int i = 0; i < runtime_count; ++i) {
+ if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) {
+ local_slot = i;
+ local_ith = ith_total - runtime[i].thread_begin;
+ break;
+ }
+ }
+ if (local_slot == -1) {
return false;
}
- const int ith = params->ith;
- const int nth_raw = params->nth;
- const int nth = nth_raw > 0 ? nth_raw : 1;
-
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
- size_t mr = kernel->get_mr();
- size_t kr = kernel->get_kr();
- size_t sr = kernel->get_sr();
-
- const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
- uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
- const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+ size_t cursor = 0;
+ for (int i = 0; i < runtime_count; ++i) {
+ const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type;
+ const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+ slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+ runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr);
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ runtime[i].lhs_offset = cursor;
+ cursor += runtime[i].lhs_packed_size;
+ }
- const size_t n_step = kernel->get_n_step();
- const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
- const size_t n_start = ith * num_n_per_thread;
+ GGML_ASSERT(cursor <= params->wsize);
+ uint8_t * scratch = static_cast<uint8_t *>(params->wdata);
- size_t n_to_process = 0;
- if (n_start < n) {
- n_to_process = num_n_per_thread;
- if ((n_start + n_to_process) > n) {
- n_to_process = n - n_start;
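+    // Hybrid column split: divide the N dimension between kernel slots in proportion to their assigned threads
+    // (with extra weight for the SME slot), aligning each share down to the slot's n_step.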
+ size_t assigned_cols = 0;
+ uint64_t weighted_total = 0;
+ if (runtime_count > 1 && sme_slot != -1) {
+ for (int i = 0; i < runtime_count; ++i) {
+ const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+ weighted_total += (uint64_t)runtime[i].assigned_threads * weight;
}
}
+ for (int i = 0; i < runtime_count; ++i) {
+ runtime[i].n_offset = assigned_cols;
+ if (runtime[i].assigned_threads == 0) {
+ runtime[i].n_cols = 0;
+ continue;
+ }
+ const size_t remaining_cols = n - assigned_cols;
+ if (remaining_cols == 0) {
+ runtime[i].n_cols = 0;
+ continue;
+ }
+ const size_t step = runtime[i].n_step ? runtime[i].n_step : 1;
+ size_t target = 0;
+ if (weighted_total > 0) {
+ const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1;
+ target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total);
+ } else {
+ target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total);
+ }
+ target = std::min(target, remaining_cols);
+ size_t aligned = round_down(target, step);
+ if (aligned == 0 && remaining_cols >= step) {
+ aligned = step;
+ }
+ runtime[i].n_cols = aligned;
+ assigned_cols += aligned;
+ }
- const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
- const size_t m_start = ith * num_m_per_thread;
- size_t m_to_process = num_m_per_thread;
- if ((m_start + m_to_process) > m) {
- m_to_process = m - m_start;
+ if (assigned_cols < n) {
+ for (int i = runtime_count - 1; i >= 0; --i) {
+ if (runtime[i].assigned_threads > 0) {
+ runtime[i].n_cols += n - assigned_cols;
+ break;
+ }
+ }
}
+ const size_t dst_stride = dst->nb[1];
- if (m_start < m) {
- const size_t src_stride = src1->nb[1];
- const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
- void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
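+    // Process src1 batches sequentially: each batch repacks its LHS into the shared scratch buffer, so consecutive
+    // batches are separated by a barrier.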
+ for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) {
+ const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+ uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
- lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
- }
+ if (runtime[local_slot].assigned_threads > 0) {
+ runtime_slot & slot = runtime[local_slot];
+ const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                const size_t slot_lhs_exec_arg = (slot_rhs_type == GGML_TYPE_Q4_0) ? QK4_0 : 0;
+ const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr);
+ int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads;
+ max_threads = std::max<int64_t>(1, max_threads);
+ const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads);
- ggml_barrier(params->threadpool);
+ if (local_ith < use_threads) {
+ const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr);
+ const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0;
- const size_t dst_stride = dst->nb[1];
- const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
- const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
- const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
- const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
- float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+ const int64_t m_start = (int64_t)local_ith * num_m_per_thread0;
+ const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+ const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+ const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+ const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0;
+
+ int64_t remaining = m_count;
+ int64_t cur = m_start;
+
+ uint8_t * lhs_packed = scratch + slot.lhs_offset;
+ while (remaining > 0) {
+ const int64_t row_in_group = cur;
+ const int64_t avail = (int64_t)m - row_in_group;
+ const int64_t take = std::min(avail, remaining);
+
+ const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]);
+ const void * src_ptr = lhs_batch_base + src_off;
+ const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+ void * dst_ptr = lhs_packed + dst_off;
+
+ slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr);
+
+ cur += take;
+ remaining -= take;
+ }
+ }
+ }
+
+ ggml_barrier(params->threadpool);
- if (n_to_process > 0) {
- kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
- sizeof(float), -FLT_MAX, FLT_MAX);
+ runtime_slot & slot = runtime[local_slot];
+ if (slot.n_cols > 0 && slot.assigned_threads > 0) {
+ int64_t active_threads = slot.assigned_threads;
+ const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads;
+ if (max_threads > 0) {
+ active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads));
+ }
+ active_threads = std::max<int64_t>(1, active_threads);
+
+ if (local_ith < active_threads) {
+ const size_t step = slot.n_step ? slot.n_step : 1;
+ const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step);
+ const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0;
+ const size_t local_start = (size_t)local_ith * chunk0;
+ const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0;
+
+ if (cols > 0) {
+ const ggml_type slot_rhs_type = slot.kernels->rhs_type;
+                    const size_t slot_lhs_exec_arg  = (slot_rhs_type == GGML_TYPE_Q4_0) ? QK4_0 : 0;
+                    const size_t slot_rhs_block_arg = (slot_rhs_type == GGML_TYPE_Q4_0) ? QK4_0 : 0;
+ const size_t global_start = slot.n_offset + local_start;
+ const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr);
+ const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg);
+ const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride);
+
+ const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset;
+ const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset;
+ float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
+
+ slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg,
+ lhs_ptr,
+ rhs_ptr,
+ dst_ptr,
+ dst_stride,
+ sizeof(float),
+ -FLT_MAX,
+ FLT_MAX);
+ }
+ }
+ }
+
+ if (batch_idx != ne12 - 1) {
+ ggml_barrier(params->threadpool);
+ }
}
return true;
}
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
+ const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data);
+ const bool has_header = kleidiai_is_weight_header_valid(header);
+
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const bool want_q8 = src0->type == GGML_TYPE_Q8_0;
+ const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+ : kleidiai_collect_q4_chain(kernel_chain);
+
ggml_kleidiai_kernels * kernels = nullptr;
- size_t block_len = 0;
- size_t num_bytes_multiplier = 0;
+ const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data);
- if (dst->src[0]->type == GGML_TYPE_Q4_0) {
- if (!ctx.kernels_q4) {
- return false;
+ if (has_header && chain_count > 0) {
+ int select_slot = 0;
+ if (select_slot >= header->slot_count) {
+ select_slot = header->slot_count - 1;
}
- kernels = ctx.kernels_q4;
- block_len = QK4_0;
- num_bytes_multiplier = sizeof(uint16_t);
- } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
- if (!ctx.kernels_q8) {
- return false;
+ if (select_slot >= 0 && select_slot < chain_count) {
+ kernels = kernel_chain[select_slot];
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot);
+ if (slot_ptr) {
+ packed_base = slot_ptr;
+ }
}
- kernels = ctx.kernels_q8;
- block_len = QK8_0;
- num_bytes_multiplier = sizeof(float);
- } else {
+ }
+
+ if (!kernels && chain_count > 0) {
+ kernels = kernel_chain[0];
+ if (has_header) {
+ const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0);
+ if (slot_ptr) {
+ packed_base = slot_ptr;
+ }
+ }
+ }
+
+ if (!kernels) {
return false;
}
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
+ const ggml_type rhs_type = kernels->rhs_type;
+ size_t block_len = 0;
+ size_t num_bytes_multiplier = 0;
+ if (rhs_type == GGML_TYPE_Q4_0) {
+ block_len = QK4_0;
+ num_bytes_multiplier = sizeof(uint16_t);
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
+ block_len = QK8_0;
+ num_bytes_multiplier = sizeof(float);
+ } else {
+ return false;
+ }
+
const size_t block_rows = kernel->get_nr();
const size_t kr = kernel->get_kr();
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
float *out = (float *)((char *)dst->data + i * nb1);
- rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
+ rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
}
return true;
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+ GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
- if (tensor->type == GGML_TYPE_Q4_0) {
- if (!ctx.kernels_q4) {
- return -1;
- }
- size_t nr = ctx.kernels_q4->gemm.get_nr();
- size_t kr = ctx.kernels_q4->gemm.get_kr();
- size_t sr = ctx.kernels_q4->gemm.get_sr();
+ kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data);
+ if (!header) {
+ return -1;
+ }
- struct kai_rhs_pack_qs4cxs1s0_param params;
- params.lhs_zero_point = 1;
- params.rhs_zero_point = 8;
- ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
- static_cast<const uint8_t *>(data),
- nullptr, nullptr, tensor->data, 0, ¶ms);
- GGML_UNUSED(data_size);
- return 0;
- } else if (tensor->type == GGML_TYPE_Q8_0) {
- if (!ctx.kernels_q8) {
- return -1;
- }
+ header->magic = GGML_KLEIDIAI_PACK_MAGIC;
+ header->version = GGML_KLEIDIAI_PACK_VERSION;
+ header->slot_count = 0;
+
+ uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data);
+ size_t cursor = sizeof(kleidiai_weight_header);
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+ : kleidiai_collect_q4_chain(kernel_chain);
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
+ std::vector<int8_t> qdata;
+ std::vector<float> scales;
+
+ if (want_q8 && slot_total > 0) {
+ qdata.resize(n * k, 0);
+ scales.resize(n, 0.0f);
const size_t row_stride = tensor->nb[1];
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
- std::vector<int8_t> qdata(n * k, 0);
- std::vector<float> scales(n, 0.0f);
-
for (size_t row = 0; row < n; ++row) {
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
static_cast<const uint8_t *>(data) + row * row_stride);
if (linear_idx >= k) {
break;
}
- const float value = d * blk.qs[l];
+ const float value = d * static_cast<float>(blk.qs[l]);
max_abs = std::max(max_abs, std::fabs(value));
}
}
if (linear_idx >= k) {
break;
}
- const float value = d * blk.qs[l];
+ const float value = d * static_cast<float>(blk.qs[l]);
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
q = std::clamp(q, -127, 127);
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
}
}
}
+ }
+
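+    // Pack one RHS copy per selected kernel slot (primary first, then the optional non-SME fallback) and record its
+    // offset and size in the weight header.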
+ for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) {
+ if (!allow_fallback && slot > 0) {
+ break;
+ }
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+ kernel_info * kernel = &kernels->gemm;
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
+ if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) {
+ continue;
+ }
+
+ const size_t nr = kernel->get_nr();
+ const size_t kr = kernel->get_kr();
+ const size_t sr = kernel->get_sr();
+ const ggml_type rhs_type = kernels->rhs_type;
+ const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 :
+ rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0;
+ if (block_len == 0) {
+ continue;
+ }
- size_t nr = ctx.kernels_q8->gemm.get_nr();
- size_t kr = ctx.kernels_q8->gemm.get_kr();
- size_t sr = ctx.kernels_q8->gemm.get_sr();
+ const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len);
+ const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+ uint8_t * dst_ptr = base_ptr + aligned_cursor;
+
+ if (rhs_type == GGML_TYPE_Q4_0) {
+ struct kai_rhs_pack_qs4cxs1s0_param params;
+ params.lhs_zero_point = 1;
+ params.rhs_zero_point = 8;
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+ static_cast<const uint8_t *>(data), nullptr, nullptr,
+ dst_ptr, 0, ¶ms);
+ } else if (rhs_type == GGML_TYPE_Q8_0) {
+ struct kai_rhs_pack_qsi8cx_params params;
+ params.lhs_zero_point = 1;
+ params.scale_multiplier = 1.0f;
+ rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+ qdata.data(), nullptr, scales.data(),
+ dst_ptr, 0, ¶ms);
+ } else {
+ continue;
+ }
+
+ header->offsets[header->slot_count] = aligned_cursor;
+ header->sizes[header->slot_count] = packed_size;
+ ++header->slot_count;
- struct kai_rhs_pack_qsi8cx_params params;
- params.lhs_zero_point = 1;
- params.scale_multiplier = 1.0f;
+ cursor = aligned_cursor + packed_size;
+ }
- ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
- qdata.data(), nullptr, scales.data(),
- tensor->data, 0, ¶ms);
- GGML_UNUSED(data_size);
- return 0;
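+    // Nothing could be packed: invalidate the header and keep the raw tensor data.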
+ if (header->slot_count == 0) {
+ header->magic = 0;
+ header->version = 0;
+ memcpy(tensor->data, data, data_size);
}
- GGML_UNUSED(data_size);
- return -1;
+ return 0;
}
};
}
static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
- return "CPU_KLEIDIAI";
-
GGML_UNUSED(buft);
+ return "CPU_KLEIDIAI";
}
static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
- return TENSOR_ALIGNMENT;
-
GGML_UNUSED(buft);
+ return TENSOR_ALIGNMENT;
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
GGML_UNUSED(buft);
+ if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) {
+ return ggml_nbytes(tensor);
+ }
+
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
- ggml_kleidiai_kernels * kernels = nullptr;
- size_t block_len = 0;
-
- if (tensor->type == GGML_TYPE_Q4_0) {
- GGML_ASSERT(ctx.kernels_q4);
- kernels = ctx.kernels_q4;
- block_len = QK4_0;
- } else if (tensor->type == GGML_TYPE_Q8_0) {
- GGML_ASSERT(ctx.kernels_q8);
- kernels = ctx.kernels_q8;
- block_len = QK8_0;
- } else {
- return 0;
+ size_t cursor = sizeof(kleidiai_weight_header);
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const bool want_q8 = tensor->type == GGML_TYPE_Q8_0;
+ const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain)
+ : kleidiai_collect_q4_chain(kernel_chain);
+ const bool allow_fallback = kleidiai_pack_fallback_allowed();
+
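+    // The allocation must hold the weight header plus one aligned packed copy per kernel slot that repack() will emit.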
+ size_t slot_count = 0;
+ for (int slot = 0; slot < slot_total; ++slot) {
+ if (!allow_fallback && slot > 0) {
+ break;
+ }
+ ggml_kleidiai_kernels * kernels = kernel_chain[slot];
+ if (!kernels) {
+ continue;
+ }
+ kernel_info * kernel = &kernels->gemm;
+ rhs_packing_info * rhs_info = &kernels->rhs_info;
+ if (!kernel || !rhs_info || !rhs_info->packed_size_ex) {
+ continue;
+ }
+
+ const ggml_type rhs_type = kernels->rhs_type;
+ const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 :
+ rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0;
+ if (block_len == 0) {
+ continue;
+ }
+
+ cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN);
+ cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len);
+ ++slot_count;
}
- const size_t nr = kernels->gemm.get_nr();
- const size_t kr = kernels->gemm.get_kr();
- const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
- const size_t raw = ggml_nbytes(tensor);
+ if (slot_count == 0) {
+ return ggml_nbytes(tensor);
+ }
- return packed > raw ? packed : raw;
+ return std::max(cursor, ggml_nbytes(tensor));
}
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
- op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
- if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+ op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() &&
+ slot_total > 0) {
+ if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) {
+ return false;
+ }
+ if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) {
return false;
}
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
- }
- else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
- if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
- (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
- return nullptr;
+ } else {
+ std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
+ const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
+ const bool has_kernel = slot_total > 0;
+ if (has_kernel && op->src[1]->ne[1] > 1) {
+ if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+ (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+ return nullptr;
+ }
+ return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
}
-
- return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
}
}
return nullptr;