#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define VK_VENDOR_ID_AMD 0x1002
#define VK_VENDOR_ID_APPLE 0x106b
uint32_t n_head_log2;
float m0;
float m1;
+
+ uint32_t gqa_ratio;
};
struct vk_op_push_constants {
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
const uint32_t D = neq0;
- const uint32_t N = neq1;
+ uint32_t N = neq1;
const uint32_t KV = nek1;
GGML_ASSERT(ne0 == D);
vk_pipeline pipeline = pipelines[aligned];
assert(pipeline);
+ uint32_t gqa_ratio = 1;
+ uint32_t qk_ratio = neq2 / nek2;
+ uint32_t workgroups_x = (uint32_t)neq1;
+ uint32_t workgroups_y = (uint32_t)neq2;
+ uint32_t workgroups_z = (uint32_t)neq3;
+
+ if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && qk_ratio <= flash_attention_num_small_rows &&
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+ // Grouped query attention: make the N dimension equal to gqa_ratio and reduce the
+ // workgroup count in the y dimension proportionally. The shader detects gqa_ratio > 1
+ // and changes its addressing calculations to index Q's dimension 2.
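+ // For example (hypothetical sizes): with neq2 = 32 Q heads and nek2 = nev2 = 8 K/V heads,
+ // qk_ratio = 4, so N becomes 4 and workgroups_y drops from 32 to 8; each workgroup then
+ // covers 4 consecutive Q heads that all read the same K/V head.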
+ gqa_ratio = qk_ratio;
+ N = gqa_ratio;
+ workgroups_y /= N;
+ }
+
if (dryrun) {
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
nbm1,
scale, max_bias, logit_softcap,
- mask != nullptr, n_head_log2, m0, m1 };
+ mask != nullptr, n_head_log2, m0, m1, gqa_ratio };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
- sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
}
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
uint32_t n_head_log2;
float m0;
float m1;
+
+ uint32_t gqa_ratio;
} p;
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
#define DECODEFUNC
#endif
+// Store the output when doing grouped query attention.
+// Rows are indexed by Q's dimension 2, and only the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+ if (r < N && c < D) {
+ uint32_t offset = (iq2 + r) * D + c;
+ data_o[o_offset + offset] = D_TYPE(elem);
+ }
+ return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
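+// Each element at row r yields the ALiBi slope for head h = iq2 + (r & (gqa_ratio - 1)),
+// i.e. r mod gqa_ratio since gqa_ratio is a power of two:
+// slope(h) = m0^(h+1) when h < n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1).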
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+ const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
+
+ const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+ const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+ return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
const uint32_t i = gl_WorkGroupID.x;
- const uint32_t iq2 = gl_WorkGroupID.y;
+ // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+ // When using grouped query attention, each workgroup handles gqa_ratio consecutive values of iq2.
+ const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
- // nb?1 are already divided by the type size and are in units of elements
- uint32_t q_stride = p.nb01;
+ // nb?1 are already divided by the type size and are in units of elements.
+ // When using grouped query attention, Q is indexed by iq2, so the stride
+ // should be nb02 (which is in bytes).
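+ // Dividing nb02 by 4 converts bytes to elements, assuming Q's element size is 4 bytes (f32).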
+ uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// hint to the compiler that strides are aligned for the aligned variant of the shader
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
- ACC_TYPE slope = ACC_TYPE(1.0);
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
// ALiBi
if (p.max_bias > 0.0f) {
- const uint32_t h = iq2;
-
- const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
- const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
- slope = pow(base, ACC_TYPE(exph));
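+ // With grouped query attention, different rows of the tile can belong to different
+ // heads, so build a per-element slope matrix instead of a single scalar slope.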
+ coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
[[dont_unroll]]
if (p.mask != 0) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+ // When using grouped query attention, all rows use the same mask.
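+ // A row stride of 0 makes every row of the loaded tile read the same mask row.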
+ if (p.gqa_ratio > 1) {
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
+ }
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
- S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
// Clear padding elements to -inf, so they don't contribute to rowmax
O = Ldiag*O;
- tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
- tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
-
- // permute dimensions
- tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
- coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
+ if (p.gqa_ratio > 1) {
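+ // perElemOpGqaStore scatters each element of the Br x D tile to output row (iq2 + r);
+ // only the first N rows are written.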
+ coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+ } else {
+ tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
+
+ // permute dimensions
+ tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
+
+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
+ }
}