}
}
-// ggml_compute_forward_flash_attn_ext
-
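+// Compute one flash-attention chunk: query rows [ir0, ir1) against KV positions [ic_start, ic_end).
+// When `partials` is non-null, the un-normalized accumulators (M, S, VKQ) are written there per query
+// row instead of being normalized and stored into dst (used by the split-KV decode path below).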
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
ggml_tensor * dst,
- int ir0, int ir1) {
+ int ir0, int ir1,
+ int64_t ic_start, int64_t ic_end,
+ float * partials, int64_t partial_stride) {
+
+ const bool write_partials = (partials != nullptr);
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
int ith = params->ith;
- // loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2*neq1);
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
- for (int64_t ic = 0; ic < nek1; ++ic) {
+ for (int64_t ic = ic_start; ic < ic_end; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
continue;
}
}
- // sinks
- if (sinks) {
+        // sinks - applied only on the first KV chunk so that the sink contributes
+        // exactly once after the per-chunk partials are merged
+        if (sinks && ic_start == 0) {
const float s = ((float *)((char *) sinks->data))[h];
            float ms = 1.0f;
            float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
+ M = s;
ggml_vec_scale_f32(DV, VKQ32, ms);
            } else {
                vs = expf(s - M);
            }

            S = S*ms + vs;
- // V /= S
- const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
- ggml_vec_scale_f32(DV, VKQ32, S_inv);
-
- // dst indices
- const int i1 = iq1;
- const int i2 = iq2;
- const int i3 = iq3;
+ if (write_partials) {
+ // Write M, S, VKQ to partials for later reduction
+            // partials layout per query row: [M, S, VKQ[DV]]
+ float * partial = partials + ir * partial_stride;
+ partial[0] = M;
+ partial[1] = S;
+ memcpy(partial + 2, VKQ32, DV * sizeof(float));
+ } else {
+ // V /= S
+ const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
+ ggml_vec_scale_f32(DV, VKQ32, S_inv);
- // original
- //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+ // dst indices
+ const int i1 = iq1;
+ const int i2 = iq2;
+ const int i3 = iq3;
- // permute(0, 2, 1, 3)
- memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+ // permute(0, 2, 1, 3)
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+ }
}
}
}
}
+// Reduction function: combines partial results across KV chunks
+// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
+static void ggml_flash_attn_ext_reduce_partials(
+ const ggml_compute_params * params,
+ ggml_tensor * dst,
+ const int64_t n_chunks,
+ const int64_t chunk_size) {
+
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+
+ const int64_t DK = k->ne[0];
+ const int64_t DV = v->ne[0];
+ const int64_t nek1 = k->ne[1];
+ const int64_t n_q_heads = q->ne[2];
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
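+    // wdata layout: nth per-thread scratch regions of (DK + 2*DV + CACHE_LINE_SIZE_F32) floats,
+    // followed by the partials buffer of [n_q_heads][n_chunks][2 + DV] floats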
+ const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
+ float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
+
+    const int64_t partials_offset = nth * wdata_per_thread;
+ const int64_t partial_size = 2 + DV;
+ const float * partials_base = (const float *) params->wdata + partials_offset;
+
+ // Output layout
+ const int64_t ne1 = dst->ne[1];
+ const int64_t ne2 = dst->ne[2];
+ const size_t nb1 = dst->nb[1];
+
+ // Each thread reduces a subset of query heads
+ for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
+ float M_final = -INFINITY;
+ float S_final = 0.0f;
+ float * VKQ_final = thread_wdata;
+ memset(VKQ_final, 0, DV * sizeof(float));
+
+ // Combine partials from all chunks
+ for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
+ const int64_t ic_start = chunk_idx * chunk_size;
+ if (ic_start >= nek1) continue;
+
+ const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
+ const float M_chunk = partial[0];
+ const float S_chunk = partial[1];
+ const float * VKQ_chunk = partial + 2;
+
+ if (S_chunk == 0.0f) continue;
+
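+            // online-softmax merge of the running result (M_final, S_final, VKQ_final)
+            // with this chunk's partial (M_chunk, S_chunk, VKQ_chunk):
+            //   M' = max(M_final, M_chunk)
+            //   S' = S_final*exp(M_final - M') + S_chunk*exp(M_chunk - M')
+            //   V' = VKQ_final*exp(M_final - M') + VKQ_chunk*exp(M_chunk - M')
+            // which yields the same V'/S' as a single pass over the whole KV range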
+ const float M_new = fmaxf(M_final, M_chunk);
+ const float scale_old = expf(M_final - M_new);
+ const float scale_new = expf(M_chunk - M_new);
+
+ for (int64_t d = 0; d < DV; ++d) {
+ VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
+ }
+ S_final = S_final * scale_old + S_chunk * scale_new;
+ M_final = M_new;
+ }
+
+ // Normalize and write to output
+ if (S_final != 0.0f) {
+ const float S_inv = 1.0f / S_final;
+ ggml_vec_scale_f32(DV, VKQ_final, S_inv);
+ }
+        // decode: iq1 = 0 and iq3 = 0, so the permuted dst offset reduces to q_head*nb1
+        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
+ }
+}
+
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const int64_t DV = nev0;
const int64_t N = neq1;
+
GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
- // parallelize by q rows using ggml_vec_dot_f32
-
- // total rows in q
- const int64_t nr = neq1*neq2*neq3;
-
- // rows per thread
const int ith = params->ith;
const int nth = params->nth;
- // disable for NUMA
- const bool disable_chunking = ggml_is_numa();
+ // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
+ const bool use_ref = params->use_ref;
- // 4x chunks per thread
- int nth_scaled = nth * 4;
- int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
- int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+ const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
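+    // Split-KV decode path: during single-token decode the only row-level parallelism is the
+    // number of query heads, so instead each thread processes one slice of the KV cache for
+    // every head and the per-chunk partials are merged afterwards.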
+    const bool use_split_kv_path = !use_ref &&
+        (neq1 == 1 && neq3 == 1) &&
+        kv_is_f32_or_f16 &&
+        (k->type == v->type) &&
+        q->type == GGML_TYPE_F32 &&
+        nek1 >= 512;
- if (nth == 1 || nchunk < nth || disable_chunking) {
- nchunk = nth;
- }
+ if (use_split_kv_path) {
+ const int64_t chunk_size = (nek1 + nth - 1) / nth;
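+        // each thread takes one contiguous slice of `chunk_size` KV positions and processes it for every query head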
- if (ith == 0) {
- // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
- ggml_threadpool_chunk_set(params->threadpool, nth);
- }
+ // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
+ const int64_t partial_size = 2 + DV;
+ float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
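+        // note: this assumes the op's CPU work-buffer size was extended to also hold these
+        // partials after the nth per-thread scratch regions (handled where the FLASH_ATTN_EXT
+        // work size is computed, not shown in this diff)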
- ggml_barrier(params->threadpool);
+ const int64_t ic_start = ith * chunk_size;
+        const int64_t ic_end = MIN(ic_start + chunk_size, nek1);
- // The number of elements in each chunk
- const int64_t dr = (nr + nchunk - 1) / nchunk;
+ const int64_t partial_stride = nth * partial_size;
+ float * chunk_partials = partials_base + ith * partial_size;
- static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
- static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
- const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
- const bool use_tiled = (q->type == GGML_TYPE_F32 &&
- kv_is_f32_or_f16 &&
- k->type == v->type &&
- nek1 % KV_TILE_SZ == 0 &&
- neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
+ if (ic_start < nek1) {
+ for (int64_t q_head = 0; q_head < neq2; q_head++) {
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(
+ params, dst, q_head, q_head + 1, ic_start, ic_end,
+ chunk_partials, partial_stride);
+ }
+ } else {
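+            // this thread got no KV positions: publish neutral partials (S = 0) so the reduction skips them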
+ for (int64_t q_head = 0; q_head < neq2; q_head++) {
+ float * q_partials = chunk_partials + q_head * partial_stride;
+ q_partials[0] = -INFINITY; // M
+ q_partials[1] = 0.0f; // S
+ }
+ }
- // The first chunk comes from our thread_id, the rest will get auto-assigned.
- int current_chunk = ith;
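+        // all threads must finish writing their partials before they can be reduced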
+ ggml_barrier(params->threadpool);
+ ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
+ } else {
- while (current_chunk < nchunk) {
- const int64_t ir0 = dr * current_chunk;
- const int64_t ir1 = MIN(ir0 + dr, nr);
+ // total rows in q
+ const int64_t nr = neq1*neq2*neq3;
- if (use_tiled) {
- ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
- } else {
- ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+ // disable for NUMA
+ const bool disable_chunking = ggml_is_numa();
+
+ // 4x chunks per thread
+ int nth_scaled = nth * 4;
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+ if (nth == 1 || nchunk < nth || disable_chunking) {
+ nchunk = nth;
}
- current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+ if (ith == 0) {
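+            // every thread starts at chunk ith, so the first unprocessed chunk is nth;
+            // this saves a bit of coordination right at the start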
+ ggml_threadpool_chunk_set(params->threadpool, nth);
+ }
+
+ ggml_barrier(params->threadpool);
+
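+        // the number of q rows in each chunk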
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+ static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
+ const bool use_tiled = !use_ref &&
+ (q->type == GGML_TYPE_F32 &&
+ kv_is_f32_or_f16 &&
+ k->type == v->type &&
+ nek1 % KV_TILE_SZ == 0 &&
+ neq1 >= Q_TILE_SZ);
+
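+        // the first chunk comes from our thread id, the rest are auto-assigned by the threadpool counter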
+ int current_chunk = ith;
+
+ while (current_chunk < nchunk) {
+ const int64_t ir0 = dr * current_chunk;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ if (use_tiled) {
+ ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+ } else {
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
+ }
+
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+ }
}
}