#define GGML_KQ_MASK_PAD 64
- // q: [n_embd_k, n_batch, n_head, ne3]
- // k: [n_embd_k, n_kv, n_head_kv, ne3]
- // v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
- // mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
- // res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
+ // q: [n_embd_k, n_batch, n_head, ne3 ]
+ // k: [n_embd_k, n_kv, n_head_kv, ne3 ]
+ // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
+ // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+ // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
//
// broadcast:
// n_head % n_head_kv == 0
- // ne3 % ne32 == 0
+ // n_head % ne32 == 0
+ // ne3 % ne33 == 0
//
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
struct ggml_context * ctx,
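// Hedged usage sketch (not part of this patch) of the shapes documented above.
// The concrete sizes are illustrative, and the call assumes the current
// ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap)
// signature; requires "ggml.h" and <math.h>.
static struct ggml_tensor * example_fattn(struct ggml_context * ctx) {
    // q: [n_embd_k = 128, n_batch = 32, n_head = 16, ne3 = 1]
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,  32, 16, 1);
    // k/v: [128, n_kv = 256, n_head_kv = 4, 1] -> n_head % n_head_kv == 0 (GQA)
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256,  4, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 256,  4, 1);
    // mask: [n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), ne32 = 1, ne33 = 1]
    // with ne32 == 1 the same mask rows are broadcast across all 16 heads
    struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 256,
            GGML_PAD(32, GGML_KQ_MASK_PAD), 1, 1);
    return ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(128.0f), 0.0f, 0.0f);
}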
memset(VKQ32, 0, DV*sizeof(float));
}
- const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+ const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
// k indices
const int ik3 = iq3 / rk3;
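// The modulo indexing above is the broadcast: head iq2 reads mask slice
// (iq2 % ne32) and sequence iq3 reads slice (iq3 % ne33). A standalone sketch
// of the same byte-offset computation (the helper name is illustrative, not
// from the patch):
static const ggml_fp16_t * mask_row(const struct ggml_tensor * mask, int64_t iq1, int64_t iq2, int64_t iq3) {
    return (const ggml_fp16_t *) ((const char *) mask->data
        + iq1*mask->nb[1]                   // query row within the padded batch
        + (iq2%mask->ne[2])*mask->nb[2]     // broadcast across heads when ne[2] < n_head
        + (iq3%mask->ne[3])*mask->nb[3]);   // broadcast across dim 3 when ne[3] < ne3
}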
return false;
}
// TODO: support broadcast
- // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+ // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+ // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
if (op->src[0]->ne[3] != 1) {
return false;
}
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
+ int32_t ne33;
uint64_t nb31;
uint64_t nb32;
+ uint64_t nb33;
int32_t ne1;
int32_t ne2;
float scale;
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,
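// Hedged host-side sketch (not from the patch) of how the new ne33/nb33 fields
// can be derived from the mask tensor before launching the kernel, falling back
// to non-broadcasting defaults when no mask is bound; the helper name is
// illustrative:
static void fattn_mask_dims(const struct ggml_tensor * mask,
        int32_t * ne32, int32_t * ne33, uint64_t * nb32, uint64_t * nb33) {
    *ne32 = mask ? (int32_t) mask->ne[2] : 1;
    *ne33 = mask ? (int32_t) mask->ne[3] : 1;
    *nb32 = mask ? mask->nb[2] : 0;
    *nb33 = mask ? mask->nb[3] : 0;
}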
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
const float m = pm[ic + tiisg];
const bool has_mask = mask != q;
// pointer to the mask
- device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
+ device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
float slope = 1.0f;
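// The kernel-side offsets mirror the CPU path: nb31 is the per-query-row byte
// stride of the mask, while nb32/nb33 are the byte strides of its head and
// sequence dims, so (iq2 % ne32) and (iq3 % ne33) pick the shared slice when
// the mask is smaller than the attention tensors in those dims.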
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
+ // TODO: support broadcast
+ // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+ // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+ if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+ return false;
+ }
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
if (mask) {
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(mask));
- GGML_ASSERT(ggml_is_3d(mask));
GGML_ASSERT(mask->ne[0] == a->ne[0]);
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
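// Dropping the ggml_is_3d(mask) requirement means the soft_max mask may now be
// 4-D as well (see the new {ne0, ne1, 1, 3} test case below), with dim 2
// broadcast against the input per the assert above.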
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
- GGML_ASSERT(mask->ne[2] == q->ne[3]);
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
- GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+ GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+ GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
}
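// A worked example of the relaxed checks (values are illustrative): with
// q->ne = {128, 32, 16, 4} and mask->ne = {256, 64, 4, 2}:
//   q->ne[2] % mask->ne[2] == 16 % 4 == 0 -> each mask slice is shared by 4 heads
//   q->ne[3] % mask->ne[3] ==  4 % 2 == 0 -> each mask slice is shared by 2 sequences
// whereas previously the mask had to satisfy mask->ne[2] == q->ne[3] exactly.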
if (max_bias > 0.0f) {
ggml_tensor * m = nullptr;
if (mask) {
- m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
+ m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
ggml_set_name(m, "m");
}
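// The test mask now carries both broadcast factors in their own dims
// (ne[2] = nr23[0], ne[3] = nr23[1]) instead of folding nr23[1] into dim 2,
// so the flash-attention test feeds a mask with non-trivial head and sequence
// dims through the new broadcast checks.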
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
if (ne0 <= 32 && ne1 <= 32) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
}
}