// ggml_compute_forward_flash_attn_ext
-static void ggml_compute_forward_flash_attn_ext_f16(
+static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params,
- ggml_tensor * dst) {
-
+ ggml_tensor * dst,
+ int ir0, int ir1) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
- const int ith = params->ith;
- const int nth = params->nth;
-
const int64_t DK = nek0;
const int64_t DV = nev0;
const int64_t N = neq1;
// parallelize by q rows using ggml_vec_dot_f32
- // total rows in q
- const int nr = neq1*neq2*neq3;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
+ int ith = params->ith;
+
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
}
}
+static void ggml_compute_forward_flash_attn_ext_f16(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
+ const int64_t DK = nek0;
+ const int64_t DV = nev0;
+ const int64_t N = neq1;
+
+ GGML_ASSERT(ne0 == DV);
+ GGML_ASSERT(ne2 == N);
+
+ // input tensor rows must be contiguous
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
+
+ GGML_ASSERT(neq0 == DK);
+ GGML_ASSERT(nek0 == DK);
+ GGML_ASSERT(nev0 == DV);
+
+ GGML_ASSERT(neq1 == N);
+
+ // dst cannot be transposed or permuted
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(nb0 <= nb1);
+ GGML_ASSERT(nb1 <= nb2);
+ GGML_ASSERT(nb2 <= nb3);
+
+ // parallelize by q rows using ggml_vec_dot_f32
+
+ // total rows in q
+ const int64_t nr = neq1*neq2*neq3;
+
+ // rows per thread
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ // disable for NUMA
+ const bool disable_chunking = ggml_is_numa();
+
+ // 4x chunks per thread
+ int nth_scaled = nth * 4;
+ int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
+ int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
+
+ if (nth == 1 || nchunk < nth || disable_chunking) {
+ nchunk = nth;
+ }
+
+ if (ith == 0) {
+ // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
+ ggml_threadpool_chunk_set(params->threadpool, nth);
+ }
+
+ ggml_barrier(params->threadpool);
+
+ // The number of elements in each chunk
+ const int64_t dr = (nr + nchunk - 1) / nchunk;
+
+ // The first chunk comes from our thread_id, the rest will get auto-assigned.
+ int current_chunk = ith;
+
+ while (current_chunk < nchunk) {
+ const int64_t ir0 = dr * current_chunk;
+ const int64_t ir1 = MIN(ir0 + dr, nr);
+
+ ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
+
+ current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
+ }
+}
+
void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
ggml_tensor * dst) {