]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cpu: introduce chunking for flash attention (llama/16829)
authorMax Krasnyansky <redacted>
Thu, 30 Oct 2025 12:26:05 +0000 (05:26 -0700)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
Factor out the core FA loop into flash_atten_f16_one_chunk and add an outter loop
on top that handles the chunks.

src/ggml-cpu/ops.cpp

index 3156bd60101d7f1cfd866913144d588b5a7b7188..c17ab10245d58240d11e114aed31fb359dc6236c 100644 (file)
@@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(
 
 // ggml_compute_forward_flash_attn_ext
 
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
         const ggml_compute_params * params,
-        ggml_tensor * dst) {
-
+        ggml_tensor * dst,
+        int ir0, int ir1) {
     const ggml_tensor * q     = dst->src[0];
     const ggml_tensor * k     = dst->src[1];
     const ggml_tensor * v     = dst->src[2];
@@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
     GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
     const int64_t DK = nek0;
     const int64_t DV = nev0;
     const int64_t N  = neq1;
@@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
     // parallelize by q rows using ggml_vec_dot_f32
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
+    int ith = params->ith;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     }
 }
 
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
+
+    GGML_ASSERT(ne0 == DV);
+    GGML_ASSERT(ne2 == N);
+
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
+
+    GGML_ASSERT(neq1 == N);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int64_t nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // disable for NUMA
+    const bool disable_chunking = ggml_is_numa();
+
+    // 4x chunks per thread
+    int nth_scaled = nth * 4;
+    int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+    int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
+
+    if (nth == 1 || nchunk < nth || disable_chunking) {
+        nchunk = nth;
+    }
+
+    if (ith == 0) {
+        // Every thread starts at ith, so the first unprocessed chunk is nth.  This save a bit of coordination right at the start.
+        ggml_threadpool_chunk_set(params->threadpool, nth);
+    }
+
+    ggml_barrier(params->threadpool);
+
+    // The number of elements in each chunk
+    const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+    // The first chunk comes from our thread_id, the rest will get auto-assigned.
+    int current_chunk = ith;
+
+    while (current_chunk < nchunk) {
+        const int64_t ir0 = dr * current_chunk;
+        const int64_t ir1 = MIN(ir0 + dr, nr);
+
+        ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+        current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+    }
+}
+
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
         ggml_tensor * dst) {