float stop,
float step);
-#define GGML_KQ_MASK_PAD 1
-
- // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
- // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
- // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
- // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
- // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
+ // q:    [n_embd_k, n_batch, n_head,    ne3 ]
+ // k:    [n_embd_k, n_kv,    n_head_kv, ne3 ]
+ // v:    [n_embd_v, n_kv,    n_head_kv, ne3 ] !! not transposed !!
+ // mask: [n_kv,     n_batch, ne32,      ne33]
+ // res:  [n_embd_v, n_head,  n_batch,   ne3 ] !! permuted !!
//
// broadcast:
// n_head % n_head_kv == 0
if (mask) {
GGML_ASSERT(ggml_is_contiguous(mask));
- GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
- "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
ggml_tensor * m = nullptr;
if (mask) {
- m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
+ m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, nr23[1]);
ggml_set_name(m, "m");
}
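For callers, the upshot of this change is that the KQ mask can now be allocated with exactly n_batch rows: the GGML_PAD(n_batch, GGML_KQ_MASK_PAD) rounding and the corresponding assert are gone. Below is a minimal sketch of the new convention; the dimensions are illustrative (not taken from the patch), and the ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap) call is assumed to match the current public signature.

```c
// Sketch only: illustrative dimensions, not taken from the patch.
#include "ggml.h"
#include <math.h>

static struct ggml_tensor * build_flash_attn(struct ggml_context * ctx) {
    const int64_t n_embd_head = 128;  // per-head size (n_embd_k == n_embd_v here)
    const int64_t n_head      = 32;   // query heads
    const int64_t n_head_kv   = 8;    // KV heads; n_head % n_head_kv == 0 (broadcast)
    const int64_t n_kv        = 4096; // KV cache length
    const int64_t n_batch     = 7;    // no longer rounded up via GGML_KQ_MASK_PAD

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head, n_batch, n_head,    1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_head, n_kv,    n_head_kv, 1); // not transposed

    // mask: [n_kv, n_batch, 1, 1] -- previously ne[1] had to be GGML_PAD(n_batch, GGML_KQ_MASK_PAD)
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, n_batch, 1, 1);

    const float scale = 1.0f/sqrtf((float) n_embd_head);

    // assumed signature: (ctx, q, k, v, mask, scale, max_bias, logit_softcap)
    struct ggml_tensor * res = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f);
    return res; // [n_embd_v, n_head, n_batch, 1] -- permuted, as documented above
}
```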