]> git.djapps.eu Git - pkg/ggml/sources/ggml/commit
Vulkan Scalar Flash Attention Refactor (llama/19625)
authorRuben Ortlam <redacted>
Tue, 24 Feb 2026 07:35:48 +0000 (08:35 +0100)
committerGeorgi Gerganov <redacted>
Wed, 25 Feb 2026 10:32:13 +0000 (12:32 +0200)
commitc5619b7de27d9fc6df1cfc67c247cd8003d9decc
tree24cab1cb17fa4d8596e4467ae6db100feebbfd89
parente7ba2a81251a6c1ce03ba9b14f0fb901d2bdc25a
Vulkan Scalar Flash Attention Refactor (llama/19625)

* vulkan: allow using fp16 in scalar flash attention shader

* split rows inside of subgroups for faster synchronization

* use row_split when Br >= 4, change reductions to use shared memory if row_split == 1

* use f32 scalar FA if f16 is not supported by device

* fix amd workgroup size issue

* optimize masksh use

* add medium rows FA shader Br size

* fixes

* add padding to mask shmem buffer

* cache q values into registers for KQ

* fuse lf accumulation, pf and v accumulation into a loop

* stage K loads through shmem

* stage V loads through shmem

* only stage through shmem on Nvidia

* default to Bc 32

* also stage V through shmem when this is done for K

* dynamic subgroups for intel

* use vectorized stores

* use float_type for dequantize4 functions

* use smaller scalar rows size for smaller rows count

* relax flash attention split_k condition to allow non-gqa use

* use minimal subgroup size on Intel

* fix shmem support function

* fix rebase issues

* fixes

* Bc 4 for scalar FA is not a valid configuration

* Use wave32 on AMD RDNA for scalar FA

* add Intel shader core count lookup-table

* fix regressions

* device tuning

* tmpsh size fix

* fix editorconfig

* refactor fa tuning logic into a single place

* fix gqa opt logic

* fix block_rows with small n_rows

* amd tuning

* fix hsk=72/80 issue

* tuning

* allow condition skipping for column check

* use float16 for Of if available

* address feedback

* fix bad RDNA performance on head size <= 128 by limiting occupancy

* allow printing pipeline stats

* cleanup and fixes

* limit occupancy for GCN for small batch FA with large HSK

* disable f16 FA for GCN AMD GPUs on the proprietary driver
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/flash_attn.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp