vulkan: Multi-pass softmax for large number of cols (llama/17892)
author Jeff Bolz <redacted>
Sat, 13 Dec 2025 09:04:29 +0000 (03:04 -0600)
committer Georgi Gerganov <redacted>
Sun, 14 Dec 2025 14:40:47 +0000 (16:40 +0200)
commit b9730e5719f6c9b2b0cbde4dd8bf369f084cda5b
tree 7e26b5644ef29e4383450c5bbdebbb9260f2f979
parent 27b5a372b403ddf03356cb61500e42031c57240c

When the number of cols is large, split each row across multiple workgroups.
There are three phases that communicate partial results through temp buffers:
(1) compute max partials
(2) take max of partials, compute sum(exp(x-max)) partials
(3) sum partials, compute scaled result
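The three phases can be illustrated with a minimal CPU sketch for a single row. This is not the shader code from this commit. The function name `softmax_multipass` and the `chunk` parameter are hypothetical. Each contiguous chunk of `chunk` columns stands in for one workgroup, and the two `std::vector`s stand in for the temp buffers. The sketch also assumes plain softmax with no scale, mask, or bias terms, and it assumes phase 3 re-reads the max partials to recover the row max.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical sketch: one "workgroup" per chunk of `chunk` columns.
std::vector<float> softmax_multipass(const std::vector<float>& row, std::size_t chunk) {
    assert(chunk > 0);
    const std::size_t ncols = row.size();
    const std::size_t nchunks = (ncols + chunk - 1) / chunk;

    // Phase 1: each workgroup writes the max of its chunk (temp buffer 1).
    std::vector<float> max_partials(nchunks, -std::numeric_limits<float>::infinity());
    for (std::size_t c = 0; c < nchunks; ++c) {
        const std::size_t end = std::min((c + 1) * chunk, ncols);
        for (std::size_t i = c * chunk; i < end; ++i) {
            max_partials[c] = std::max(max_partials[c], row[i]);
        }
    }

    // Phase 2: each workgroup reduces all max partials to the row max, then
    // writes sum(exp(x - max)) over its own chunk (temp buffer 2).
    std::vector<float> sum_partials(nchunks, 0.0f);
    for (std::size_t c = 0; c < nchunks; ++c) {
        const float row_max = *std::max_element(max_partials.begin(), max_partials.end());
        const std::size_t end = std::min((c + 1) * chunk, ncols);
        for (std::size_t i = c * chunk; i < end; ++i) {
            sum_partials[c] += std::exp(row[i] - row_max);
        }
    }

    // Phase 3: each workgroup sums all sum partials and writes its scaled outputs.
    std::vector<float> out(ncols);
    for (std::size_t c = 0; c < nchunks; ++c) {
        const float row_max = *std::max_element(max_partials.begin(), max_partials.end());
        float total = 0.0f;
        for (float s : sum_partials) total += s;
        const std::size_t end = std::min((c + 1) * chunk, ncols);
        for (std::size_t i = c * chunk; i < end; ++i) {
            out[i] = std::exp(row[i] - row_max) / total;
        }
    }
    return out;
}
```

Phases 1 and 2 are separate because subtracting the row max keeps `exp` from overflowing, and no workgroup knows the global max until every chunk has reported its partial. In the sketch, every chunk recomputes the reduction of the partials. This mirrors workgroups that each read the whole temp buffer rather than waiting on a single reducer.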
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp