]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
kleidiai: add support for get_rows (llama/14676)
authorCharles Xu <redacted>
Mon, 21 Jul 2025 13:49:52 +0000 (15:49 +0200)
committerGeorgi Gerganov <redacted>
Thu, 24 Jul 2025 17:57:40 +0000 (20:57 +0300)
* kleidiai: add support for get_rows

* apply fixes based on code review

* apply more fixes based on code review

src/ggml-cpu/CMakeLists.txt
src/ggml-cpu/kleidiai/kernels.cpp
src/ggml-cpu/kleidiai/kernels.h
src/ggml-cpu/kleidiai/kleidiai.cpp

index 13f745b20620cab4e2bd4a83d05cdad321883dfa..2cc42d4b02af95511144e5e220317a3e529fd844 100644 (file)
@@ -496,9 +496,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.9.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.11.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "2a8e1bb55d201557553545536489a017")
+        set(KLEIDIAI_ARCHIVE_MD5  "3fe9e5ab964c375c53839296eb71eaa2")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
index 910fd0ee4e7430f455451c1bf86eaa8e07b9e03b..ddd29d002d1ca4bbae1cb214186eb99f2d64b7b8 100644 (file)
 
 #include "kai_common.h"
 
+#include "simd-mappings.h"
+
 #include "kernels.h"
 
 #define NELEMS(x) sizeof(x) / sizeof(*x)
+
+static const size_t INT4_PER_BYTE = 2;
+static const size_t INT4_BITS     = 4;
+static const int Q4_0_ZERO_POINT  = 8;
+const size_t INT4_PER_UINT16      = 4;
+
+static void dequantize_row_qsi4c32pscalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t nc,
+    float *out,
+    size_t nr_pack,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    size_t group_idx = row_idx / nr_pack;
+    size_t row_in_group = row_idx % nr_pack;
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    size_t num_blocks = nc / bl;
+    const uint8_t *block_ptr = packed_group;
+
+    for (size_t b = 0; b < num_blocks; ++b) {
+        uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
+        size_t num_segments = bl / kr;
+        size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
+
+        for (size_t s = 0; s < num_segments; ++s) {
+            const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
+            const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
+            for (size_t k = 0; k < num_bytes_per_segment; ++k) {
+                uint8_t byte = qbytes[k] ^ 0x88;
+                int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
+                int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
+                out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
+                out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
+            }
+        }
+        block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
+    }
+}
+
+static void dequantize_row_qsi4c32ps1s0scalef16(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    const size_t num_blocks = k / bl;
+    const size_t bl4 = bl / INT4_PER_UINT16;
+
+    size_t group_idx = row_idx / nr;
+    size_t row_in_group = row_idx % nr;
+
+    const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
+    const uint16_t *qdata = (const uint16_t *)packed_group;
+    const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
+
+    for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
+        uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
+        float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
+
+        for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
+            uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
+
+            for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
+                int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
+                out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
+            }
+        }
+    }
+    GGML_UNUSED(kr);
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
+            /* .to_float      = */ dequantize_row_qsi4c32ps1s0scalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
-            /* .pack_func   = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .packed_stride = */ NULL,
+            /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
+            /* .to_float      = */ NULL,
         },
         /* .required_cpu       = */ CPU_FEATURE_SME,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
         /* .lhs_type           = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
         },
         /* .rhs_info = */ {
-            /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
-            /* .pack_func   = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float      = */ dequantize_row_qsi4c32pscalef16,
         },
         /* .required_cpu       = */ CPU_FEATURE_DOTPROD,
         /* .lhs_type           = */ GGML_TYPE_F32,
index 3b268d4a22acad18d05b026881fface70ee0424a..bc8f33405d1fe1ec52755992639bf46bc37b74a5 100644 (file)
@@ -71,12 +71,15 @@ struct rhs_packing_info {
         std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
         std::function<size_t(size_t n, size_t k)>
     > packed_size;
+    size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
     std::variant<
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
             const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
             const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
     > pack_func;
+    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
+          size_t kr, size_t bl, size_t num_bytes_multiplier);
 };
 
 struct ggml_kleidiai_kernels {
index fafe45e6c5c5102855cc29c95ebae2fac960f3fb..3a513a55d7654cb28d0c7061ca9ebfcc95798ec0 100644 (file)
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
     ggml_kleidiai_kernels * kernels;
 } static ctx = { CPU_FEATURE_NONE, NULL };
 
+static const char* cpu_feature_to_string(cpu_feature f) {
+    switch (f) {
+        case CPU_FEATURE_NONE:    return "NONE";
+        case CPU_FEATURE_DOTPROD: return "DOTPROD";
+        case CPU_FEATURE_I8MM:    return "I8MM";
+        case CPU_FEATURE_SVE:     return "SVE";
+        case CPU_FEATURE_SME:     return "SME";
+        default:                  return "UNKNOWN";
+    }
+}
+
 static void init_kleidiai_context(void) {
 
     ggml_critical_section_start();
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
             ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
         ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+#ifndef NDEBUG
+        if (ctx.kernels) {
+            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        }
+#endif
     }
     ggml_critical_section_end();
 }
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
 
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
+        if (op->op != GGML_OP_MUL_MAT) {
+            return false;
+        }
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
         GGML_ASSERT(kernels);
         kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_kv_cache(params, dst);
             }
+        } else if (dst->op == GGML_OP_GET_ROWS) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_get_rows(params, dst);
+            }
         }
         return false;
     }
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
+        GGML_ASSERT(ctx.kernels);
+
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
+        kernel_info * kernel        = &ctx.kernels->gemm;
+
+        const int64_t nc     = ne00;
+        const int64_t nr     = ggml_nelements(src1);
+
+        const size_t block_rows = kernel->get_nr();
+        const size_t kr         = kernel->get_kr();
+
+        const size_t num_bytes_multiplier = sizeof(uint16_t);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const int dr = (nr + nth - 1) / nth;
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        for (int64_t i = ir0; i < ir1; ++i) {
+            GGML_ASSERT(src1->type == GGML_TYPE_I32);
+            int64_t row_idx = ((const int32_t *)src1->data)[i];
+            GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
+
+            float *out = (float *)((char *)dst->data + i * nb1);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+        }
+
+        return true;
+    }
+
 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
+        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
         GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
@@ -351,17 +417,12 @@ public:
         size_t kr      = ctx.kernels->gemm.get_kr();
         size_t sr      = ctx.kernels->gemm.get_sr();
 
-#ifndef NDEBUG
-        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
-        GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
-#endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
 
         return 0;
-
         GGML_UNUSED(data_size);
     }
 };
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
-    GGML_UNUSED(buffer);
     return GGML_STATUS_SUCCESS;
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
     GGML_UNUSED(buft);
 }
 
+static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
+    GGML_ASSERT(ctx.kernels);
+
+    const size_t n  = tensor->ne[1];
+    const size_t k  = tensor->ne[0];
+    const size_t nr = ctx.kernels->gemm.get_nr();
+    const size_t kr = ctx.kernels->gemm.get_kr();
+
+    return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT &&
+        if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
             op->src[0]->type == GGML_TYPE_Q4_0 &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
-            if (op->src[1]->type == GGML_TYPE_F32 &&
+            if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
                 ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
                 return true;
             }
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
                            /* .alloc_buffer     = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
                            /* .get_alignment    = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
                            /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
-                           /* .get_alloc_size   = */ nullptr,  // defaults to ggml_nbytes
+                           /* .get_alloc_size   = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
                            /* .is_host          = */ nullptr,
                            },
         /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),