git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commit
vulkan: Add bfloat16 support (#12554)
author     Jeff Bolz <redacted>
           Thu, 1 May 2025 18:49:39 +0000 (13:49 -0500)
committer  GitHub <redacted>
           Thu, 1 May 2025 18:49:39 +0000 (20:49 +0200)
commit     79f26e9e125b21760aeb016f34bfd42a93f48351
tree       5b84c585608eaaceec5a74f8e2ab512f67ceda07
parent     fc727bcdd5a311c7c69a76dbf87f4784e828c7b4

* vulkan: Add bfloat16 support

This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.

It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.

The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

Compile a variant of the scalar mul_mm shader that promotes the bf16
values to float, and use it when either the bf16 extension or coopmat
support isn't available.

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and reenable -O
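The first bullet says matrix-vector multiply and the copy/get_rows shaders handle bf16 without VK_KHR_shader_bfloat16. One plausible reason is that bfloat16 is just the upper 16 bits of an IEEE-754 binary32 value, so promotion to fp32 is a 16-bit shift. The host-side C++ sketch below illustrates that idea. It is not the GLSL from types.comp, the function names are hypothetical, and the round-to-nearest-even rule for fp32 to bf16 is an assumption, not something this commit message states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// bfloat16 keeps the sign bit, all 8 exponent bits and the top 7 mantissa
// bits of a binary32 value, i.e. it is stored as the upper 16 bits.
using bf16_bits = std::uint16_t;

// Promote bf16 -> fp32 by placing the 16-bit payload in the high half.
// This is exact: every bf16 value is representable as fp32.
float bf16_to_fp32(bf16_bits h) {
    std::uint32_t u = static_cast<std::uint32_t>(h) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Demote fp32 -> bf16, rounding the dropped 16 bits to nearest, ties to even
// (assumed rounding mode for this sketch). NaNs are forced to quiet NaNs so
// that rounding cannot carry them into infinity.
bf16_bits fp32_to_bf16(float f) {
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    if ((u & 0x7fffffffu) > 0x7f800000u) {        // NaN
        return static_cast<bf16_bits>((u >> 16) | 0x0040u);
    }
    u += 0x7fffu + ((u >> 16) & 1u);               // round to nearest even
    return static_cast<bf16_bits>(u >> 16);
}

// One row of a matrix-vector product: bf16 weights are promoted to fp32
// on load and accumulated in fp32, so no bf16 arithmetic is needed.
float dot_bf16_f32(const bf16_bits* a, const float* x, std::size_t n) {
    assert(n == 0 || (a != nullptr && x != nullptr));
    float acc = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        acc += bf16_to_fp32(a[i]) * x[i];
    }
    return acc;
}
```

For example, a row encoded from {1, -2, 0.5} dotted with x = {3, 1, 4} gives 3.0f; all of these values are exact in bf16, so the promotion loses nothing.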
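The second bullet describes a fallback: the bf16 cooperative-matrix shader is only usable when both the bf16 extension and coopmat support are present (and, per the first bullet, when the shaders were built with a glslc that supports the extension); otherwise a scalar mul_mm variant that promotes bf16 to float is used. A minimal sketch of that decision, assuming a simplified capability struct. `DeviceCaps`, `Bf16MatMulPath` and `select_bf16_matmul_path` are hypothetical names, not ggml-vulkan identifiers, and the sketch does not model the separate cooperative-matrix shaders (mul_mm.comp and mul_mm_cm2.comp are both touched by this commit).

```cpp
// Hypothetical names for illustration only; not the ggml-vulkan API.
enum class Bf16MatMulPath {
    CoopmatBf16,       // cooperative-matrix shader consuming bf16 directly
    ScalarPromoteF32,  // scalar mul_mm variant that converts bf16 to float on load
};

struct DeviceCaps {
    bool shader_bfloat16;  // VK_KHR_shader_bfloat16 usable (device and glslc support)
    bool coopmat;          // cooperative matrix support available
};

// The bf16 coopmat path needs both capabilities; if either is missing,
// fall back to the scalar shader that promotes bf16 to fp32.
constexpr Bf16MatMulPath select_bf16_matmul_path(const DeviceCaps& caps) {
    return (caps.shader_bfloat16 && caps.coopmat)
               ? Bf16MatMulPath::CoopmatBf16
               : Bf16MatMulPath::ScalarPromoteF32;
}
```

Making the selector `constexpr` lets the full truth table be checked at compile time with `static_assert`.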
13 files changed:
ggml/src/ggml-vulkan/CMakeLists.txt
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp