]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cpu: FA split across kv for faster TG (#19209)
authorAman Gupta <redacted>
Mon, 2 Feb 2026 17:19:55 +0000 (01:19 +0800)
committerGitHub <redacted>
Mon, 2 Feb 2026 17:19:55 +0000 (01:19 +0800)
* ggml-cpu: split across kv for faster TG

* simplify sinks application

* add ref impl

ggml/include/ggml-cpu.h
ggml/src/ggml-cpu/ggml-cpu-impl.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp
ggml/src/ggml-cpu/ops.cpp
tests/test-backend-ops.cpp

index 4f3b99c8d07573f2aab2e5d4f87294ebcaa467b2..e3e067c916f14e785da031beea50743bb6644636 100644 (file)
@@ -19,6 +19,9 @@ extern "C" {
         // abort ggml_graph_compute when true
         ggml_abort_callback abort_callback;
         void *              abort_callback_data;
+
+        // use only reference implementations
+        bool use_ref;
     };
 
     // numa strategies
@@ -132,6 +135,8 @@ extern "C" {
     GGML_BACKEND_API void ggml_backend_cpu_set_threadpool    (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
     GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
+    GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);
+
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);
index 0e8dd0ae0530c598925a6535dea06fcf65361b40..88a9c9ec0572499204b4a1d7d3d5aba51d2b90af 100644 (file)
@@ -24,6 +24,9 @@ struct ggml_compute_params {
     void * wdata;
 
     struct ggml_threadpool * threadpool;
+
+    // use reference implementation
+    bool use_ref;
 };
 
 
index b1de2ae871601f4f1eeca9d8e12a870291ad585a..3e5f01e3fb650afd3221130690dffc7e8169fd47 100644 (file)
@@ -5,7 +5,6 @@
 #include "ggml-backend.h"
 #include "traits.h"
 #include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
 #include "ggml-impl.h"
 #include "quants.h"
 #include "ggml-threading.h"
@@ -2867,12 +2866,20 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_FLASH_ATTN_EXT:
                     {
+                        const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
                         const int64_t DK = node->src[1]->ne[0];
                         const int64_t DV = node->src[2]->ne[0];
 
                         // Tiled flash attention scratch (tile sizes defined in common.h)
                         // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
-                        cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
+                        size_t prefill  = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
+
+                        // Decode path: n_kv_chunks = n_tasks (one chunk per thread)
+                        // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
+                        size_t n_chunks = n_tasks;
+                        size_t decode   = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
+
+                        cur += MAX(prefill, decode);
                     } break;
                 case GGML_OP_FLASH_ATTN_BACK:
                     {
@@ -2929,11 +2936,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     set_numa_thread_affinity(state->ith);
 
     struct ggml_compute_params params = {
-        /*.ith       =*/ state->ith,
-        /*.nth       =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
-        /*.wsize     =*/ cplan->work_size,
-        /*.wdata     =*/ cplan->work_data,
-        /*.threadpool=*/ tp,
+        /*.ith        =*/ state->ith,
+        /*.nth        =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
+        /*.wsize      =*/ cplan->work_size,
+        /*.wdata      =*/ cplan->work_data,
+        /*.threadpool =*/ tp,
+        /*.use_ref    =*/ cplan->use_ref,
     };
 
     GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
index f4713a42185d70407bb373ff44ffbe4412b6f5d3..ddf1737a3171db2a79566b399711171627f30950 100644 (file)
@@ -105,6 +105,8 @@ struct ggml_backend_cpu_context {
 
     ggml_abort_callback abort_callback;
     void *              abort_callback_data;
+
+    bool                use_ref;  // use reference implementation
 };
 
 static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
@@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend
 
     cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
     cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+    cpu_plan->cplan.use_ref             = cpu_ctx->use_ref;
 
     return cpu_plan;
 }
@@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
 
     cplan.abort_callback      = cpu_ctx->abort_callback;
     cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+    cplan.use_ref             = cpu_ctx->use_ref;
 
     return ggml_graph_compute(cgraph, &cplan);
 }
@@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->work_size           = 0;
     ctx->abort_callback      = NULL;
     ctx->abort_callback_data = NULL;
+    ctx->use_ref             = false;
 
     ggml_backend_t cpu_backend = new ggml_backend {
         /* .guid    = */ ggml_backend_cpu_guid(),
@@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->use_ref = use_ref;
+}
+
 // CPU backend - device
 
 struct ggml_backend_cpu_device_context {
@@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
     if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
         return (void *)ggml_is_numa;
     }
+    if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) {
+        return (void *)ggml_backend_cpu_set_use_ref;
+    }
 
     // threadpool - TODO:  move to ggml-base
     if (strcmp(name, "ggml_threadpool_new") == 0) {
index 48c8964361900305a212f2c86a7e0c2a7bd1ea71..ce15b18ce0ee97c4b2f4ddeebd33070bac68052a 100644 (file)
@@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
     }
 }
 
-// ggml_compute_forward_flash_attn_ext
-
 static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
         ggml_tensor * dst,
-        int ir0, int ir1) {
+        int ir0, int ir1,
+        int64_t ic_start, int64_t ic_end,
+        float * partials, int64_t partial_stride) {
+
+    const bool write_partials = (partials != nullptr);
     const ggml_tensor * q     = dst->src[0];
     const ggml_tensor * k     = dst->src[1];
     const ggml_tensor * v     = dst->src[2];
@@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
     int ith = params->ith;
 
-    // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
         const int iq3 = ir/(neq2*neq1);
@@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
 
-        for (int64_t ic = 0; ic < nek1; ++ic) {
+        for (int64_t ic = ic_start; ic < ic_end; ++ic) {
             const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
@@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             }
         }
 
-        // sinks
-        if (sinks) {
+        // sinks - apply only on the first kv-chunk
+        if (sinks && ic_start == 0) {
             const float s = ((float *)((char *) sinks->data))[h];
 
             float ms = 1.0f;
@@ -8247,6 +8248,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 
             if (s > M) {
                 ms = expf(M - s);
+                M = s;
                 ggml_vec_scale_f32(DV, VKQ32, ms);
             } else {
                 vs = expf(s - M);
@@ -8255,20 +8257,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
             S = S*ms + vs;
         }
 
-        // V /= S
-        const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
-        ggml_vec_scale_f32(DV, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
+        if (write_partials) {
+            // Write M, S, VKQ to partials for later reduction
+            // partials layout: [M, S, VKQ[DV]] per query head
+            float * partial = partials + ir * partial_stride;
+            partial[0] = M;
+            partial[1] = S;
+            memcpy(partial + 2, VKQ32, DV * sizeof(float));
+        } else {
+            // V /= S
+            const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+            ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
 
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+        }
     }
 }
 
@@ -8546,6 +8554,78 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
     }
 }
 
+// Reduction function: combines partial results across KV chunks
+// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
+static void ggml_flash_attn_ext_reduce_partials(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const int64_t n_chunks,
+        const int64_t chunk_size) {
+
+    const ggml_tensor * q = dst->src[0];
+    const ggml_tensor * k = dst->src[1];
+    const ggml_tensor * v = dst->src[2];
+
+    const int64_t DK        = k->ne[0];
+    const int64_t DV        = v->ne[0];
+    const int64_t nek1      = k->ne[1];
+    const int64_t n_q_heads = q->ne[2];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
+    float *       thread_wdata     = (float *) params->wdata + ith * wdata_per_thread;
+
+    const int64_t partials_offset  = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
+    const int64_t partial_size     = 2 + DV;
+    const float * partials_base    = (const float *) params->wdata + partials_offset;
+
+    // Output layout
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const size_t  nb1 = dst->nb[1];
+
+    // Each thread reduces a subset of query heads
+    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
+        float   M_final   = -INFINITY;
+        float   S_final   = 0.0f;
+        float * VKQ_final = thread_wdata;
+        memset(VKQ_final, 0, DV * sizeof(float));
+
+        // Combine partials from all chunks
+        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+            const int64_t ic_start = chunk_idx * chunk_size;
+            if (ic_start >= nek1) continue;
+
+            const float * partial   = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
+            const float   M_chunk   = partial[0];
+            const float   S_chunk   = partial[1];
+            const float * VKQ_chunk = partial + 2;
+
+            if (S_chunk == 0.0f) continue;
+
+            const float M_new     = fmaxf(M_final, M_chunk);
+            const float scale_old = expf(M_final - M_new);
+            const float scale_new = expf(M_chunk - M_new);
+
+            for (int64_t d = 0; d < DV; ++d) {
+                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
+            }
+            S_final = S_final * scale_old + S_chunk * scale_new;
+            M_final = M_new;
+        }
+
+        // Normalize and write to output
+        if (S_final != 0.0f) {
+            const float S_inv = 1.0f / S_final;
+            ggml_vec_scale_f32(DV, VKQ_final, S_inv);
+        }
+        // iq1=0, iq3=0 for decode
+        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
+    }
+}
+
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -8567,6 +8647,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t DV = nev0;
     const int64_t N  = neq1;
 
+
     GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
@@ -8587,60 +8668,92 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // parallelize by q rows using ggml_vec_dot_f32
-
-    // total rows in q
-    const int64_t nr = neq1*neq2*neq3;
-
-    // rows per thread
     const int ith = params->ith;
     const int nth = params->nth;
 
-    // disable for NUMA
-    const bool disable_chunking = ggml_is_numa();
+    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
+    const bool use_ref = params->use_ref;
 
-    // 4x chunks per thread
-    int nth_scaled = nth * 4;
-    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
-    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
+    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
 
-    if (nth == 1 || nchunk < nth || disable_chunking) {
-        nchunk = nth;
-    }
+    if (use_split_kv_path) {
+        const int64_t chunk_size = (nek1 + nth - 1) / nth;
 
-    if (ith == 0) {
-        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
-        ggml_threadpool_chunk_set(params->threadpool, nth);
-    }
+        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
+        const int64_t partial_size  = 2 + DV;
+        float *       partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
 
-    ggml_barrier(params->threadpool);
+        const int64_t ic_start = ith * chunk_size;
+        const int64_t ic_end   = std::min(ic_start + chunk_size, nek1);
 
-    // The number of elements in each chunk
-    const int64_t dr = (nr + nchunk - 1) / nchunk;
+        const int64_t partial_stride = nth * partial_size;
+        float *       chunk_partials = partials_base + ith * partial_size;
 
-    static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
-    static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
-    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
-    const bool use_tiled = (q->type == GGML_TYPE_F32 &&
-                            kv_is_f32_or_f16 &&
-                            k->type == v->type &&
-                            nek1 % KV_TILE_SZ == 0 &&
-                            neq1 >= Q_TILE_SZ);  // Only use tiled for batch >= tile size
+        if (ic_start < nek1) {
+            for (int64_t q_head = 0; q_head < neq2; q_head++) {
+                ggml_compute_forward_flash_attn_ext_f16_one_chunk(
+                    params, dst, q_head, q_head + 1, ic_start, ic_end,
+                    chunk_partials, partial_stride);
+            }
+        } else {
+            for (int64_t q_head = 0; q_head < neq2; q_head++) {
+                float * q_partials = chunk_partials + q_head * partial_stride;
+                q_partials[0] = -INFINITY;  // M
+                q_partials[1] = 0.0f;       // S
+            }
+        }
 
-    // The first chunk comes from our thread_id, the rest will get auto-assigned.
-    int current_chunk = ith;
+        ggml_barrier(params->threadpool);
+        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
+    } else {
 
-    while (current_chunk < nchunk) {
-        const int64_t ir0 = dr * current_chunk;
-        const int64_t ir1 = MIN(ir0 + dr, nr);
+        // total rows in q
+        const int64_t nr = neq1*neq2*neq3;
 
-        if (use_tiled) {
-            ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
-        } else {
-            ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+
+        // 4x chunks per thread
+        int nth_scaled = nth * 4;
+        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+        if (nth == 1 || nchunk < nth || disable_chunking) {
+            nchunk = nth;
         }
 
-        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        if (ith == 0) {
+            ggml_threadpool_chunk_set(params->threadpool, nth);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+        static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
+        static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
+        const bool use_tiled = !use_ref &&
+                               (q->type == GGML_TYPE_F32 &&
+                                kv_is_f32_or_f16 &&
+                                k->type == v->type &&
+                                nek1 % KV_TILE_SZ == 0 &&
+                                neq1 >= Q_TILE_SZ);
+
+        int current_chunk = ith;
+
+        while (current_chunk < nchunk) {
+            const int64_t ir0 = dr * current_chunk;
+            const int64_t ir1 = MIN(ir0 + dr, nr);
+
+            if (use_tiled) {
+                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+            } else {
+                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
+            }
+
+            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+        }
     }
 }
 
index 411467e968f43c0270e3c888efe00339f63a2f6f..90cc0d7da2ff2c21b9d10fe78dd56d3b5f87b7ad 100644 (file)
@@ -8591,6 +8591,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             output_printer->print_operation(info);
             return false;
         }
+        // Use reference implementation on the CPU backend for comparison
+        using ggml_backend_cpu_set_use_ref_t = void (*)(ggml_backend_t, bool);
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_cpu));
+        auto * set_use_ref = (ggml_backend_cpu_set_use_ref_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_use_ref");
+        if (set_use_ref) {
+            set_use_ref(backend_cpu, true);
+        }
 
         size_t n_ok = 0;
         size_t                   tests_run = 0;