vulkan: Implement grouped query attention in the coopmat2 FA shader (llama/12559)
author: Jeff Bolz <redacted>
Wed, 2 Apr 2025 17:40:32 +0000 (12:40 -0500)
committer: Georgi Gerganov <redacted>
Tue, 8 Apr 2025 08:47:46 +0000 (11:47 +0300)
commit: 6d31906ae62e14896127fce0c59e0e9f94cbca37
tree: 96dadc55be75fe1ea3d39426c4c9da7b857dd38a
parent: c8c51eda3b29c0359901e92824e4b9be97ec8eda

When adjacent batches of Q (the heads in dimension 2) share the same batch of
K/V, batch them into the same workgroup. For example, when:

dst(128,32,1,1) = FA(q(128,1,32,1), k(128,16640,8,1), v(128,16640,8,1))

previously we would run 32 workgroups computing 1 result each, now we will
run 8 workgroups computing 4 results each.
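A minimal Python sketch of the dispatch arithmetic in the example above. It
assumes the usual grouped-query-attention broadcast, in which Q head i reads
K/V head i // (n_q_heads / n_kv_heads), so each run of adjacent Q heads maps to
one K/V head. All function names are hypothetical illustrations and do not
correspond to identifiers in ggml-vulkan.cpp or flash_attn_cm2.comp.

```python
# Sketch of GQA workgroup batching; names are hypothetical, not ggml-vulkan's.

def gqa_ratio(n_q_heads: int, n_kv_heads: int) -> int:
    """Number of adjacent Q heads that read the same K/V head."""
    if n_kv_heads <= 0 or n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a positive multiple of n_kv_heads")
    return n_q_heads // n_kv_heads


def kv_head_for_q_head(q_head: int, ratio: int) -> int:
    """K/V head broadcast to a given Q head (assumed division-based mapping)."""
    return q_head // ratio


def dispatch_ungrouped(n_q_heads: int) -> tuple[int, int]:
    """(workgroups, results per workgroup) with one workgroup per Q head."""
    return n_q_heads, 1


def dispatch_grouped(n_q_heads: int, n_kv_heads: int) -> tuple[int, int]:
    """(workgroups, results per workgroup) when Q heads sharing K/V are batched."""
    ratio = gqa_ratio(n_q_heads, n_kv_heads)
    return n_q_heads // ratio, ratio


def q_heads_in_workgroup(wg: int, ratio: int) -> list[int]:
    """Q heads computed by workgroup `wg` in the grouped scheme."""
    return list(range(wg * ratio, (wg + 1) * ratio))


# Shapes from the example: q(128,1,32,1), k/v(128,16640,8,1).
print(dispatch_ungrouped(32))      # (32, 1)
print(dispatch_grouped(32, 8))     # (8, 4)
print(q_heads_in_workgroup(1, 4))  # [4, 5, 6, 7]
```

Because the mapping uses integer division, the Q heads placed in one workgroup
are contiguous and all read the same K/V head (workgroup 1 covers Q heads 4..7,
which all map to K/V head 1).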

This doesn't directly translate to better performance (at least when you have
>=32 SMs), but in a subsequent change I'll enable split_k, which will scale much
better with 4x fewer workgroups.
Files changed:
  src/ggml-vulkan/ggml-vulkan.cpp
  src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp