From: Jeff Bolz Date: Sat, 12 Apr 2025 08:44:48 +0000 (-0500) Subject: vulkan: use aligned loads for flash attention mask (llama/12853) X-Git-Tag: upstream/1.7.5+105~33 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=751e42b21eb2edc005d2027d48cba12c3361f6ba;p=pkg%2Fggml%2Fsources%2Fwhisper.cpp vulkan: use aligned loads for flash attention mask (llama/12853) Rewrite the stride logic for the mask tensor in the FA shader to force the stride to be aligned, to allow using more efficient loads. --- diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index a8f4bc41..e1baa85f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -201,6 +201,11 @@ void main() { uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; uint32_t k_stride = p.nb11; uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { @@ -209,6 +214,7 @@ void main() { k_stride &= ~7; v_stride &= ~7; #endif + m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); @@ -261,10 +267,7 @@ void main() { if (p.mask != 0) { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); - // When using grouped query attention, all rows use the same mask. - if (p.gqa_ratio > 1) { - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1); - } + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat mv;