git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commit
vulkan: Add bfloat16 support (llama/12554)
author     Jeff Bolz <redacted>
           Thu, 1 May 2025 18:49:39 +0000 (13:49 -0500)
committer  Georgi Gerganov <redacted>
           Wed, 7 May 2025 12:39:32 +0000 (15:39 +0300)
commit     fd1cb9fc12bf8ceb36281601a1faa4741c76b2fe
tree       711e5c89de3328534dcaae185a06aa090bbac9d0
parent     17f6b8225eafbd5dae4c7d728eb3c8f3434ba15a

* vulkan: Add bfloat16 support

This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.
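Promoting bf16 to fp32 is cheap because a bf16 value is the upper 16 bits of an IEEE-754 binary32 value: widening is a 16-bit left shift, and only narrowing needs rounding. The CPU-side C++ sketch here illustrates that idea under stated assumptions. The helper names `bf16_to_fp32`, `fp32_to_bf16`, and `mul_mat_vec_bf16` are hypothetical. bf16 values are assumed to be stored as raw `uint16_t` bit patterns. The round-to-nearest-even narrowing, which keeps NaNs quiet, is one common choice and not necessarily what the copy shaders in this commit do. This is not the GLSL shader code.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helpers; a CPU sketch, not the ggml-vulkan shader code.
// bf16 = sign bit + 8 exponent bits + top 7 mantissa bits of an fp32,
// so widening to fp32 is exact: shift the 16 bits into the high half.
inline float bf16_to_fp32(uint16_t b) {
    uint32_t bits = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Narrowing fp32 -> bf16 with round-to-nearest-even (assumed policy).
inline uint16_t fp32_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        // NaN: truncate and force the quiet bit instead of rounding,
        // so the result cannot turn into an infinity.
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }
    // Add 0x7fff, plus 1 if the kept low bit is odd, to break ties to even.
    uint32_t rounding_bias = 0x7fffu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

// y = A * x for a row-major bf16 matrix A (rows x cols) and an fp32 vector x.
// Each weight is promoted to fp32 on load and accumulated in fp32,
// so no bf16 arithmetic (and no bf16 extension) is needed.
inline void mul_mat_vec_bf16(const uint16_t* a, const float* x, float* y,
                             int rows, int cols) {
    for (int r = 0; r < rows; ++r) {
        float sum = 0.0f;
        for (int c = 0; c < cols; ++c) {
            sum += bf16_to_fp32(a[r * cols + c]) * x[c];
        }
        y[r] = sum;
    }
}
```

For example, `fp32_to_bf16(1.0f)` yields `0x3F80`, and `bf16_to_fp32(0x3F80)` gives back `1.0f` exactly. Rounding only matters in the narrowing direction, for instance when an fp32 tensor would be copied into a bf16 one.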

It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.

The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

Compile a variant of the scalar mul_mm shader that promotes the bf16
values to float, and use it when the bf16 extension or the coopmat
extensions aren't available.
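The fallback rule can be illustrated with a CPU-side C++ sketch. It selects the cooperative-matrix path only when both the bf16 extension and coopmat support are present. Otherwise it runs a scalar matrix multiply that promotes each bf16 value to fp32 as it is loaded. `Bf16MatmulPath`, `choose_bf16_matmul_path`, and `mul_mm_bf16_scalar` are hypothetical names, and both operands are assumed to be row-major raw bf16 bit patterns. The sketch does not reproduce the backend's actual pipeline-selection code.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical names throughout; a sketch of the idea, not ggml-vulkan code.
enum class Bf16MatmulPath { Coopmat, ScalarPromote };

// The coopmat bf16 shader needs both the bf16 extension and cooperative-matrix
// support; if either is missing, use the scalar shader variant that promotes.
inline Bf16MatmulPath choose_bf16_matmul_path(bool has_bf16_ext, bool has_coopmat) {
    return (has_bf16_ext && has_coopmat) ? Bf16MatmulPath::Coopmat
                                         : Bf16MatmulPath::ScalarPromote;
}

// Exact bf16 -> fp32 widening: place the 16 bits in the high half of an fp32.
inline float bf16_to_fp32(uint16_t b) {
    uint32_t bits = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Scalar C = A * B with A (m x k) and B (k x n) in row-major bf16 and C in fp32.
// Values are promoted to fp32 as they are loaded; accumulation is in fp32.
inline void mul_mm_bf16_scalar(const uint16_t* a, const uint16_t* b, float* c,
                               int m, int n, int k) {
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
            float sum = 0.0f;
            for (int p = 0; p < k; ++p) {
                sum += bf16_to_fp32(a[i * k + p]) * bf16_to_fp32(b[p * n + j]);
            }
            c[i * n + j] = sum;
        }
    }
}
```

Because the promoted path uses only fp32 arithmetic, it runs on any device. The extension-dependent path is reserved for hardware that can consume bf16 directly in cooperative-matrix operations.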

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and reenable -O
13 files changed:
ggml/src/ggml-vulkan/CMakeLists.txt
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp