]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commit
CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E...
authorOliver Simons <redacted>
Wed, 3 Sep 2025 17:59:16 +0000 (19:59 +0200)
committerGitHub <redacted>
Wed, 3 Sep 2025 17:59:16 +0000 (19:59 +0200)
commit661ae31c9c68201577e70278285b349a5a662caf
treee097105be093157bb0045f99440aa0d2dcd59cba
parent407c23786dd0d3a503e9429eead96e611d3950c9
CUDA: Optimize `rms_norm_f32` kernel and its fused variants, giving 1-6% perf E2E (#15715)

* Add fastdiv, use it in modulo and use modulo in rms_norm_f32

Fastdiv is much faster way to do integer division, which was identified
as bottleneck in rms_norm_f32

* Support more `block_size` values in `rms_norm_f32`

This makes us more flexible in selecting the optimal threads w.r.t
paralellizing across a col vs. launch-overheads of threads and mio
throttles

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Johannes Gäßler <redacted>
* Replace modulo with fastmodulo in `rms_norm_f32`

* Use `BinPackArguments=true` for formating function calls

Will file a separate PR to adjust .clang-format file

* Update ggml/src/ggml-cuda/common.cuh

Co-authored-by: Johannes Gäßler <redacted>
* Use uint3 for both `fastdiv` and `fastmodulo`

The compiler seems to reliably optimize away the unused .z component in
the fastdiv use-case, see https://godbolt.org/z/rx8KPrKr3

* More constrained type declarations

Co-authored-by: Johannes Gäßler <redacted>
* Rename fastdiv and fastmodulo variables to shared variable name

As suggest by JohannesGaessler, this increases clarity of the intended
use

* Pack fastdiv/fastmodulo constants into uint2/uint3 objects

By packing constants to be used together into a struct, we are less
likely to make errors.

* Rename function parameter of fastmodulo

`modulo_consts` is more fitting/descriptive

---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/norm.cu