uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
+ uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
float max_bias;
float logit_softcap;
- uint32_t mask;
- uint32_t n_head_log2;
+ uint32_t mask_n_head_log2;
float m0;
float m1;
const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;
+ const uint32_t nem3 = mask ? mask->ne[3] : 0;
const uint32_t HSK = nek0;
const uint32_t HSV = nev0;
}
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
- qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, and reduce
// workgroups proportionally in the y dimension. The shader detects gqa_ratio > 1
// and changes the addressing calculations to index Q's dimension 2.
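+ // This path requires the mask, if any, to broadcast across dim 2 (nem2 <= 1).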
}
}
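+ // pack the mask-enable flag into bit 16 and n_head_log2 into the low 16 bits
+ // (the shader unpacks these via MASK_ENABLE_BIT and N_LOG2_MASK)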
+ uint32_t mask_n_head_log2 = ((mask != nullptr) << 16) | n_head_log2;
+
const vk_flash_attn_push_constants pc = { N, KV,
(uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
- nem1, nem2,
+ nem1, nem2, nem3,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
scale, max_bias, logit_softcap,
- mask != nullptr, n_head_log2, m0, m1,
+ mask_n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
ggml_vk_sync_buffers(subctx);
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
- // TODO: support broadcast
- // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
- // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
- if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
- return false;
- }
// It's straightforward to support different K/V dequant, but it would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;
+ uint32_t nem3;
uint32_t nb01;
uint32_t nb02;
float max_bias;
float logit_softcap;
- uint32_t mask;
- uint32_t n_head_log2;
+ uint32_t mask_n_head_log2;
float m0;
float m1;
uint32_t k_num;
} p;
+#define MASK_ENABLE_BIT (1<<16)
+#define N_LOG2_MASK 0xFFFF
+
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
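+ // ALiBi slope for head h: m0^(h+1) when h < n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1)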
+ uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK;
+
+ const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
}
uint32_t m_offset = 0;
- if (p.nem2 != 1) {
- m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
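+ // broadcast the mask over Q's dims 2 and 3: select the mask slice (iq2 % nem2, iq3 % nem3)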
+ if (p.nem2 != 1 || p.nem3 != 1) {
+ m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
}
}
- if (p.mask != 0) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);