// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
+
for (int64_t ic = 0; ic < nek1; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) {
    continue;
}
}
+static void ggml_compute_forward_flash_attn_ext_tiled(
+ const ggml_compute_params * params,
+ ggml_tensor * dst,
+ int ir0, int ir1) {
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int64_t DK = nek0;
+ const int64_t DV = nev0;
+ const int64_t N = neq1;
+
+ GGML_ASSERT(ne0 == DV);
+ GGML_ASSERT(ne2 == N);
+
+ // input tensor rows must be contiguous
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+ GGML_ASSERT(neq0 == DK);
+ GGML_ASSERT(nek0 == DK);
+ GGML_ASSERT(nev0 == DV);
+
+ GGML_ASSERT(neq1 == N);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ GGML_ASSERT(k->type == v->type);
+ const ggml_type kv_type = k->type;
+
+ const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
+ const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
+ const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
+ const size_t kv_type_size = ggml_type_size(kv_type);
+
+ // broadcast factors
+ const int64_t rk2 = neq2/nek2;
+ const int64_t rk3 = neq3/nek3;
+
+ const int64_t rv2 = neq2/nev2;
+ const int64_t rv3 = neq3/nev3;
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
+
+ const uint32_t n_head = neq2;
+ const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ int ith = params->ith;
+
+ static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
+ static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
+
+ GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
+
+ int ir = ir0;
+ while (ir < ir1) {
+ // q indices for the start of this tile
+ const int iq3 = ir/(neq2*neq1);
+ const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+ const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+ // Number of valid rows in this tile:
+ // - limited by tile size (Q_TILE_SZ)
+ // - limited by chunk boundary (ir1 - ir)
+ // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
+ const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
+ GGML_ASSERT(tile_rows > 0);
+
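+ // ALiBi slope for this head; 1.0f when max_bias == 0, so the mask values are applied unscaled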
+ const uint32_t h = iq2; // head index
+ const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
+ float S[Q_TILE_SZ];
+ float M[Q_TILE_SZ];
+
+ for (int i = 0; i < Q_TILE_SZ; ++i) {
+ S[i] = 0.0f;
+ M[i] = -INFINITY;
+ }
+
+ // Per-thread scratch layout:
+ // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
+ // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
+ // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
+ // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
+ // V32: KV_TILE_SZ * DV (F32 buffer for the V tile, used for F16 conversion)
+ float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
+
+ void * Q_q = base;
+ float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
+ float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
+ float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
+ float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
+
+ memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
+ memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
+
+ // k indices
+ const int ik3 = iq3 / rk3;
+ const int ik2 = iq2 / rk2;
+
+ // v indices
+ const int iv3 = iq3 / rv3;
+ const int iv2 = iq2 / rv2;
+
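+ // convert the valid Q rows of this tile from F32 to the K/V type so kv_vec_dot can be used for the K.Q products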
+ for (int tq = 0; tq < tile_rows; tq++) {
+ const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
+ kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
+ }
+ // Zero-pad remaining rows
+ for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
+ memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
+ }
+
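+ // stream over the KV sequence in tiles of KV_TILE_SZ rows, keeping a running per-row
+ // maximum M and softmax denominator S (online softmax)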
+ for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
+
+ // skip the tile entirely if all the masks are -inf
+ if (mask) {
+ bool can_skip = true;
+ for (int tq = 0; tq < tile_rows; tq++) {
+ const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
+ for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+ mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
+ if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
+ can_skip = false;
+ }
+ }
+ }
+
+ if (can_skip) {
+ continue;
+ }
+ }
+
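+ // raw attention scores for the tile: KQ[tq][tk] = scale * dot(K row ic+tk, Q row tq);
+ // zero-padded Q rows produce zero scores and are never written to dst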
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+ const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
+ for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+ const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
+ float s;
+ kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
+ KQ[tq * KV_TILE_SZ + tk] = s * scale;
+ }
+ }
+
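+ // logit soft-capping: scale was already divided by logit_softcap above, so this
+ // yields logit_softcap * tanh(scale * s / logit_softcap) for the raw scores s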
+ if (logit_softcap != 0.0f) {
+ ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
+ ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
+ }
+
+ if (mask) {
+ ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
+ }
+
+ bool skip[Q_TILE_SZ] = {};
+
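+ // per-row online softmax update: if this tile raises the running max, rescale the
+ // existing accumulators by exp(Mold - Mnew), then add exp(kq - Mnew) into S;
+ // fully masked rows are skipped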
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+ float * kq_row = KQ + tq * KV_TILE_SZ;
+
+ float tile_max;
+ ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
+
+ if (tile_max == -INFINITY) {
+ skip[tq] = true;
+ continue;
+ }
+
+ const float Mold = M[tq];
+ const float Mnew = fmaxf(Mold, tile_max);
+
+ if (Mnew > Mold) {
+ const float ms = expf(Mold - Mnew);
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+ S[tq] *= ms;
+ }
+ M[tq] = Mnew;
+
+ S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
+ }
+
+ // Convert V tile to F32 first (if F16), then do MAD
+ // On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
+ // TODO: on ARM, native f16 should be faster
+ if (kv_type == GGML_TYPE_F16) {
+ for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+ const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
+ ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
+ }
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+ if (skip[tq]) continue;
+ float * vkq_row = VKQ32 + tq * DV;
+ for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+ const float p = KQ[tq * KV_TILE_SZ + tk];
+ ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
+ }
+ }
+ } else {
+ for (int tq = 0; tq < Q_TILE_SZ; tq++) {
+ if (skip[tq]) continue;
+ float * vkq_row = VKQ32 + tq * DV;
+ for (int tk = 0; tk < KV_TILE_SZ; tk++) {
+ const float p = KQ[tq * KV_TILE_SZ + tk];
+ const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
+ ggml_vec_mad_f32(DV, vkq_row, v_row, p);
+ }
+ }
+ }
+ }
+
+ // sinks (apply only to valid rows in the tile)
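+ // the sink logit enters the softmax denominator (rescaling the accumulators if it
+ // becomes the new max) but contributes nothing to the numerator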
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ for (int tq = 0; tq < tile_rows; tq++) {
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M[tq]) {
+ ms = expf(M[tq] - s);
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
+ } else {
+ vs = expf(s - M[tq]);
+ }
+
+ S[tq] = S[tq] * ms + vs;
+ }
+ }
+
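+ // normalize the accumulated V rows by the softmax denominator and write the valid rows to dst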
+ for (int tq = 0; tq < tile_rows; tq++) {
+ // V /= S
+ const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
+ ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
+
+ // dst indices
+ const int i1 = iq1 + tq;
+ const int i2 = iq2;
+ const int i3 = iq3;
+
+ // permute(0, 2, 1, 3)
+ memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
+ }
+
+ ir += tile_rows;
+ }
+}
+
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;
+ static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
+ static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
+ const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
+ const bool use_tiled = (q->type == GGML_TYPE_F32 &&
+ kv_is_f32_or_f16 &&
+ k->type == v->type &&
+ nek1 % KV_TILE_SZ == 0 &&
+ neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
+
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
- ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+ if (use_tiled) {
+ ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
+ } else {
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+ }
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}