]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
kleidiai: generalize compute_forward_kv_cache to compute_forward_fp16 (#15817)
authorCharles Xu <redacted>
Sat, 6 Sep 2025 14:08:43 +0000 (16:08 +0200)
committerGitHub <redacted>
Sat, 6 Sep 2025 14:08:43 +0000 (22:08 +0800)
ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

index 7a830448eb3e1f63241a1cc37a42eb1a06a4b805..95f873dc77923617064321d4da2dd5c6e9ea8f9f 100644 (file)
@@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                 return compute_forward_q4_0(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
-                return compute_forward_kv_cache(params, dst);
+                return compute_forward_fp16(params, dst);
             }
         } else if (dst->op == GGML_OP_GET_ROWS) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return false;
     }
 
-    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+    bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
         static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
 
         const ggml_tensor * src0 = dst->src[0];
@@ -534,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }
-            else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
-                     op->src[0]->op == GGML_OP_VIEW &&
-                     (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op ==  GGML_OP_SOFT_MAX) &&
-                     op->src[1]->ne[1] > 1) {
-                if ((op->src[0]->nb[0] != 2) ||
-                    (op->src[1]->nb[0] != 4) ||
-                    (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+            else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
+                if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
                     (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
                     return nullptr;
                 }