uint32_t num_blocks;
};
+struct vk_op_flash_attn_split_k_reduce_push_constants {
+ uint32_t D; // number of components per row of O (HSV)
+ uint32_t ne1; // number of rows per batch
+ uint32_t ne2; // batch dimension 2
+ uint32_t ne3; // batch dimension 3
+ uint32_t k_num; // number of split_k chunks to reduce
+ uint32_t sinks; // nonzero if a sinks buffer is provided
+};
+
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
if (device->subgroup_clustered && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
GGML_ASSERT(0);
}
- if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+ if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
// and change addressing calculations to index Q's dimension 2.
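+ // Example (illustrative): N == 1 with neq2 == 32 Q heads over nek2 == 8 KV heads
+ // gives qk_ratio == 4, so N becomes 4 and workgroups_y shrinks by 4x.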
gqa_ratio = qk_ratio;
N = gqa_ratio;
- workgroups_y /= N;
+ workgroups_y /= gqa_ratio;
}
bool small_rows = N <= get_fa_num_small_rows(path);
}
assert(pipeline);
+ // Request the pipeline early so it gets compiled and wg_denoms is initialized (used below).
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
uint32_t split_kv = KV;
uint32_t split_k = 1;
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
- // Try to use split_k when KV is large enough to be worth the overhead
- if (workgroups_x == 1 && shader_core_count > 0) {
+ // Try to use split_k when KV is large enough to be worth the overhead.
+ // Must either be a single batch or be using gqa; we can't mix the two.
+ if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
// Try to run two workgroups per SM.
- split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
+ split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
split_k = CEIL_DIV(KV, split_kv);
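+ // Example (illustrative, assuming alignment == 256): KV == 7680 and 16 cores with one
+ // workgroup in each dimension gives split_k == 32, then split_kv == 256 and split_k == 30.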
- workgroups_x = split_k;
}
}
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
- const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
+ // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
+ // For L/M, the order is (inner to outer) [ne1, {L,M}, k, ne2, ne3], i.e. per
+ // (k, ne2, ne3) a block of ne1 L values followed by ne1 M values.
+ const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
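+ // Example (illustrative): HSV == 64, ne1 == 8, split_k == 30, ne2 == 8, ne3 == 1 stores
+ // 30*8 O matrices of 64*8 floats followed by 30*8 blocks of 2*8 L/M floats.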
if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
GGML_ABORT("Requested preallocation size is too large");
}
{
// Request descriptor sets
- ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
}
if (ctx->prealloc_split_k_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
-
+ workgroups_x *= pipeline->wg_denoms[0];
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
// there's no more than one tile of rows (i.e. workgroups_x would have been
- // one). We reuse workgroups_x to mean the number of splits, so we need to
- // cancel out the divide by wg_denoms[0].
+ // one). workgroups_x was multiplied by wg_denoms[0] above so the divide applied
+ // at dispatch cancels out, leaving split_k workgroups per batch in x.
- pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+ pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
ggml_vk_sync_buffers(ctx, subctx);
- const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
+ const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
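+ // The reduce dispatch is sized in elements {ne1, HSV, ne2*ne3}; its wg_denoms of
+ // {1, subgroup_size, 1} turn that into one workgroup per row, D chunk, and batch.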
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{split_k_buf, sinks_buf, dst_buf},
- pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
+ pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
ctx->prealloc_split_k_need_sync = true;
} else {
+ if (gqa_ratio > 1) {
+ // When using gqa, we want one actual workgroup per batch, so cancel out the divide by wg_denoms[0].
+ workgroups_x *= pipeline->wg_denoms[0];
+ }
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
pc, { workgroups_x, workgroups_y, workgroups_z });
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
- uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+ uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
- uint32_t m_offset = 0;
+ uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
- m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
- uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
+ // Note: O and Q have coords 1 and 2 swapped.
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
}
}
- o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+ o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
}
}
- uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+ uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
}
uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
- iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+ gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
q_stride, k_stride, v_stride, m_stride;
void init_indices()
N = p.N;
KV = p.KV;
- i = gl_WorkGroupID.x;
- split_k_index = 0;
-
if (p.k_num > 1) {
i = 0;
- split_k_index = gl_WorkGroupID.x;
+ // batch and split_k share gl_WorkGroupID.x
+ gqa_iq1 = gl_WorkGroupID.x / p.k_num;
+ split_k_index = gl_WorkGroupID.x % p.k_num;
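+ // e.g. (illustrative) with p.k_num == 4, gl_WorkGroupID.x == 9 decodes to
+ // gqa_iq1 == 2 and split_k_index == 1.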
+ } else if (p.gqa_ratio > 1) {
+ i = 0;
+ gqa_iq1 = gl_WorkGroupID.x;
+ split_k_index = 0;
+ } else {
+ i = gl_WorkGroupID.x;
+ gqa_iq1 = 0;
+ split_k_index = 0;
}
Tr = CEIL_DIV(N, Br);
barrier();
}
- uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+ uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
- uint32_t m_offset = 0;
+ uint32_t m_offset = gqa_iq1*KV;
if (p.nem2 != 1 || p.nem3 != 1) {
- m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
+ m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV;
}
[[dont_unroll]]
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
- uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
+ // Note: O and Q have coords 1 and 2 swapped.
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
}
}
- o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+ o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
}
}
- uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+ uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;
- uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
+ uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03;
coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));
Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
- uint32_t m_offset = 0;
+ uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/;
if (p.nem2 != 1 || p.nem3 != 1) {
- m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
+ m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
- uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
+ // Note: O and Q have coords 1 and 2 swapped.
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
- o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
+ o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
- uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
+ uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
layout (push_constant) uniform parameter {
uint D;
- uint N;
+ uint ne1;
+ uint ne2;
uint ne3;
uint k_num;
uint sinks;
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
- const uint iq3 = gl_WorkGroupID.z;
+ const uint i2 = gl_WorkGroupID.z % p.ne2;
+ const uint i3 = gl_WorkGroupID.z / p.ne2;
uint D = p.D;
- uint N = p.N;
uint k_num = p.k_num;
- uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
- uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
- uint lm_stride = N * 2;
+ uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + k_num * (i2 + p.ne2 * i3)) + n;
+ uint m_offset = l_offset + p.ne1;
+ uint lm_stride = p.ne1 * 2;
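+ // These offsets mirror the store layout in the FA shaders: all O matrices first, then
+ // for each (k, i2, i3) a block of ne1 L values followed by ne1 M values.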
// Compute the max m value for the row
float m_max = -1.0/0.0;
if (d < D) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
- uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
+ uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
- data_d[iq3 * D * N + D * n + d] = O;
+ data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O;
}
}
// Qwen3-VL-8B https://github.com/ggml-org/llama.cpp/issues/17012
test_cases.emplace_back(new test_flash_attn_ext(72, 72, 16, {1, 1}, 5776, 5776, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
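+ // Meant to exercise the gqa split_k path: gqa ratio 8 with KV 7680, nb = 1 and nb = 4.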
+ test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+ test_cases.emplace_back(new test_flash_attn_ext(64, 64, 8, {8, 1}, 7680, 4, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {