]> git.djapps.eu Git - pkg/ggml/sources/ggml/commit
vulkan: optimize rms_norm, and allow the work to spread across multiple SMs (llama...
authorJeff Bolz <redacted>
Sat, 23 Aug 2025 18:16:17 +0000 (13:16 -0500)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:02 +0000 (12:54 +0300)
commit4a0fa115bee6f57795eb09f4ed5f2fece7d28587
treee02ac2e331b99a67642a533d7898ba3a81c00cec
parent0791daa27bc43d1e0dd36d655ab1a8d55c505a28
vulkan: optimize rms_norm, and allow the work to spread across multiple SMs (llama/15281)

* vulkan: optimize rms_norm, and allow the work to spread across multiple SMs

There are really two parts to this change:
(1) Some optimizations similar to what we have in soft_max, to unroll with
different numbers of iterations.
(2) A fusion optimization where we detect add followed by rms_norm, and make
the add shader atomically accumulate the values^2 into memory. Then the
rms_norm shader can just load that sum. This allows the rms_norm to be
parallelized across multiple workgroups, it just becomes a simple per-element
multiply.

The fusion optimization is currently only applied when the rms_norm is on a
single vector. This previously always ran on a single SM. It could apply more
broadly, but when there are other dimensions the work can already spread across
SMs, and there would be some complexity to tracking multiple atomic sums.

* Change add+rms_norm optimization to write out an array of partial sums
rather than using atomic add, to make it deterministic. The rms_norm
shader fetches a subgroup's worth in parallel and uses subgroupAdd to
add them up.

* complete rebase against fused adds - multi_add shader can also compute partial sums

* fix validation errors

* disable add_rms_fusion for Intel due to possible driver bug

* resolve against #15489, sync after clearing partial sums
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/add.comp
src/ggml-vulkan/vulkan-shaders/multi_add.comp
src/ggml-vulkan/vulkan-shaders/rms_norm.comp
src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp