]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Add bfloat16 support (#12554)
authorJeff Bolz <redacted>
Thu, 1 May 2025 18:49:39 +0000 (13:49 -0500)
committerGitHub <redacted>
Thu, 1 May 2025 18:49:39 +0000 (20:49 +0200)
* vulkan: Add bfloat16 support

This adds bfloat16 matrix multiply support based on VK_KHR_shader_bfloat16.
The extension is required for coopmat multiply support, but matrix-vector
multiply trivially promotes bf16 to fp32 and doesn't require the extension.
The copy/get_rows shaders also don't require the extension.

It's probably possible to fall back to non-coopmat and promote to fp32 when
the extension isn't supported, but this change doesn't do that.

The coopmat support also requires a glslc that supports the extension, which
currently requires a custom build.

* vulkan: Support bf16 tensors without the bf16 extension or coopmat support

Compile a variant of the scalar mul_mm shader that will promote the bf16
values to float, and use that when either the bf16 extension or the coopmat
extensions aren't available.

* vulkan: bfloat16 fixes (really works without bfloat16 support now)

* vulkan: fix spirv-val failure and reenable -O

13 files changed:
ggml/src/ggml-vulkan/CMakeLists.txt
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 9d028f718d0fe01aafb827be9fb2574107cc9de8..31816219c06fdafd311955a70b0d39b8a5333ae1 100644 (file)
@@ -71,6 +71,22 @@ if (Vulkan_FOUND)
         add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     endif()
 
+    # Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
+        message(STATUS "GL_EXT_bfloat16 not supported by glslc")
+        set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
+    else()
+        message(STATUS "GL_EXT_bfloat16 supported by glslc")
+        set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
+        add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    endif()
+
     target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
@@ -142,6 +158,7 @@ if (Vulkan_FOUND)
                     -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
                     -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
                     -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
+                    -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
             BUILD_COMMAND ${CMAKE_COMMAND} --build .
             INSTALL_COMMAND ${CMAKE_COMMAND} --install .
             INSTALL_DIR ${CMAKE_BINARY_DIR}
index 4614c3c1563017cd79d20b507ec0b573253a2ec6..eac0b422bc67bb973ad033279a42866c9b389256 100644 (file)
 
 #include "ggml-vulkan-shaders.hpp"
 
+// remove this once it's more widely available in the SDK
+#if !defined(VK_KHR_shader_bfloat16)
+
+#define VK_KHR_shader_bfloat16 1
+#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION                          1
+#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME                        "VK_KHR_shader_bfloat16"
+#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
+#define VK_COMPONENT_TYPE_BFLOAT16_KHR                               ((VkComponentTypeKHR)1000141000)
+
+typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
+    VkStructureType                       sType;
+    void*                                 pNext;
+    VkBool32                              shaderBFloat16Type;
+    VkBool32                              shaderBFloat16DotProduct;
+    VkBool32                              shaderBFloat16CooperativeMatrix;
+} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
+#endif
+
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
@@ -266,8 +284,9 @@ struct vk_device_struct {
     bool subgroup_require_full_support;
 
     bool coopmat_support;
-    bool coopmat_acc_f32_support;
-    bool coopmat_acc_f16_support;
+    bool coopmat_acc_f32_support {};
+    bool coopmat_acc_f16_support {};
+    bool coopmat_bf16_support {};
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
@@ -293,6 +312,7 @@ struct vk_device_struct {
 
     vk_matmul_pipeline pipeline_matmul_f32 {};
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
+    vk_matmul_pipeline pipeline_matmul_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
 
@@ -301,6 +321,7 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
+    vk_matmul_pipeline pipeline_matmul_id_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;
     vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
 
@@ -333,8 +354,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
@@ -1791,6 +1812,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     if (!device->pipeline_matmul_id_f32) {
         device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
     }
+    if (!device->pipeline_matmul_bf16) {
+        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_bf16) {
+        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
 
     std::vector<std::future<void>> compiles;
     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
@@ -1900,6 +1927,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
 
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+        }
+#endif
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1921,6 +1953,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f16,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        }
+#endif
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1974,6 +2011,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
+        }
+#endif
 
         if (device->coopmat_acc_f16_support) {
             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2022,6 +2064,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+        }
+#endif
 
         if (device->coopmat_acc_f16_support) {
             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2104,6 +2151,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2139,6 +2188,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2191,6 +2242,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2226,6 +2279,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2246,8 +2301,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-#undef CREATE_MM
     }
+    // reusing CREATE_MM from the fp32 path
+    if ((device->coopmat2 || device->coopmat_support)
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        && !device->coopmat_bf16_support
+#endif
+        ) {
+        // use scalar tile sizes
+        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+
+        l_wg_denoms = {128, 128, 1 };
+        m_wg_denoms = { 64,  64, 1 };
+        s_wg_denoms = { 32,  32, 1 };
+
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+    }
+#undef CREATE_MM
 
     // mul mat vec
 
@@ -2266,6 +2339,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1),  mul_mat_vec_f32_f32_f32_len,  mul_mat_vec_f32_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1),  mul_mat_vec_f16_f32_f32_len,  mul_mat_vec_f16_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2288,6 +2362,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2311,6 +2386,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",  mul_mat_vec_id_f16_f32_len,  mul_mat_vec_id_f16_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
@@ -2356,6 +2432,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // get_rows
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2373,6 +2450,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2410,10 +2488,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2578,6 +2659,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool coopmat2_support = false;
         device->coopmat_support = false;
         device->integer_dot_product = false;
+        bool bfloat16_support = false;
 
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2608,6 +2690,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
                        !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
                 device->integer_dot_product = true;
 #endif
+            } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
+                       !getenv("GGML_VK_DISABLE_BFLOAT16")) {
+                bfloat16_support = true;
             }
         }
 
@@ -2794,6 +2879,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
         }
 #endif
 
+#if defined(VK_KHR_shader_bfloat16)
+        VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
+        bfloat16_features.pNext = nullptr;
+        bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
+        if (bfloat16_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
+            last_struct = (VkBaseOutStructure *)&bfloat16_features;
+            device_extensions.push_back("VK_KHR_shader_bfloat16");
+        }
+#endif
+
         VkPhysicalDeviceMaintenance4Features maint4_features {};
         maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
         if (maintenance4_support) {
@@ -2991,6 +3087,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
                     device->coopmat_int_n = prop.NSize;
                     device->coopmat_int_k = prop.KSize;
                 }
+#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+                if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+                    prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+                    prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+                    prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
+                ) {
+                    // coopmat sizes not set yet
+                    if (device->coopmat_m == 0) {
+                        device->coopmat_bf16_support = true;
+                        device->coopmat_m = prop.MSize;
+                        device->coopmat_n = prop.NSize;
+                        device->coopmat_k = prop.KSize;
+                    } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
+                        // Only enable if shape is identical
+                        device->coopmat_bf16_support = true;
+                    }
+                }
+#endif
             }
 
             if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
@@ -2998,11 +3113,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
                 device->coopmat_support = false;
             }
+            if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
+                device->coopmat_bf16_support = false;
+            }
         }
 
         if (device->coopmat_support) {
             device_extensions.push_back("VK_KHR_cooperative_matrix");
         }
+#if defined(VK_KHR_shader_bfloat16)
+        if (device->coopmat_bf16_support) {
+            device_extensions.push_back("VK_KHR_shader_bfloat16");
+        }
+#endif
 #endif
         device->name = GGML_VK_NAME + std::to_string(idx);
 
@@ -3459,6 +3582,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
         return ctx->device->pipeline_matmul_f32_f16;
     }
+    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+        return ctx->device->pipeline_matmul_bf16;
+    }
     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_f16_f32.f16acc;
@@ -3530,6 +3656,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
@@ -3562,6 +3689,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_id_f32;
     }
+    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+        return ctx->device->pipeline_matmul_id_bf16;
+    }
     if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
@@ -3615,6 +3745,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
     switch (a_type) {
         case GGML_TYPE_F32:
         case GGML_TYPE_F16:
+        case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
@@ -4350,6 +4481,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cpy_f16_f16;
         }
     }
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_bf16;
+        } else {
+            return ctx->device->pipeline_cpy_f32_bf16;
+        }
+    }
     if (src->type == GGML_TYPE_F32) {
         switch (to) {
         case GGML_TYPE_Q4_0:
@@ -4477,8 +4615,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
                               !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
                               !ggml_vk_dim01_contiguous(src1);
 
+    // If src0 is BF16, try to use a BF16 x BF16 multiply
+    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
     bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
@@ -4488,25 +4630,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     if (mmp == nullptr) {
         // Fall back to f16 dequant mul mat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
         quantize_y = false;
     }
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 
-    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
+    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
     const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
@@ -4527,12 +4669,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     vk_pipeline to_q8_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -5032,7 +5174,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
     } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
-               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
         ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -5100,27 +5242,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
                               !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
                               !ggml_vk_dim01_contiguous(src1);
 
+    // If src0 is BF16, try to use a BF16 x BF16 multiply
+    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+    const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
+    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
@@ -5139,12 +5285,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -9230,6 +9376,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 switch (src0_type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
@@ -9265,10 +9412,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
-                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
+                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
                     !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
                     return false;
                 }
+                if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
+                    // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
+                    // So don't support this combination for now.
+                    return false;
+                }
 
                 return true;
             } break;
@@ -9341,6 +9493,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
@@ -9371,6 +9524,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     switch (src1_type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
index d6e0b2a5a5dd6a4f10851f8faafd364ef995e0ce..ad13f69b3fb1a60fbe61833d8a5a1c1d0c4c1457 100644 (file)
@@ -12,6 +12,9 @@ endif()
 if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
     add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 endif()
+if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+endif()
 set(TARGET vulkan-shaders-gen)
 add_executable(${TARGET} vulkan-shaders-gen.cpp)
 install(TARGETS ${TARGET} RUNTIME)
index dd828c232628c868b7e61d12b44b449ca658e99d..6567a8c54cf493bec8214b9f45324dacd2a74ad9 100644 (file)
@@ -18,7 +18,11 @@ void main() {
     // fast path for when all four iterations are in-bounds
     if (idx + (num_iter-1)*num_threads < p.ne) {
         [[unroll]] for (uint i = 0; i < num_iter; ++i) {
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
+
+#if defined(DATA_D_BF16)
+            float f = float(data_a[get_aoffset() + idx]);
+            data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
             data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
 #else
             data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@@ -31,7 +35,10 @@ void main() {
                 continue;
             }
 
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
+#if defined(DATA_D_BF16)
+            float f = float(data_a[get_aoffset() + idx]);
+            data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
             data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
 #else
             data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
index 29c9064942d93d244273fcd098acf8bf90a7ed1c..f476a2e3dd83e816af6d4fb819f5cdcdb062cf53 100644 (file)
@@ -12,7 +12,10 @@ void main() {
         return;
     }
 
-#ifndef OPTIMIZATION_ERROR_WORKAROUND
+#if defined(DATA_D_BF16)
+    float f = float(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
+#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
     data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
 #else
     data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
index 2a162a2c81543ec66795b76de74817e5e5cc2584..0d9739d40609af24b9dbc6d4b8124dae815f02b4 100644 (file)
@@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_BF16)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
+}
+#endif
+
 #if defined(DATA_A_Q4_0)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
-#if defined(DATA_A_F32) || defined(DATA_A_F16)
+#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(0, 0);
 }
index e877ed7796a8ff1a0d9844efc87f3bbf450ee155..ee6b86a18ddf2e4a6ad108967e6d83f05b23cca7 100644 (file)
@@ -20,9 +20,14 @@ void main() {
     const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
     const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
 
+#if defined(DATA_A_BF16)
+    FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+#else
+    FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+#endif
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
-    data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
+    data_d[d_offset + i00] = D_TYPE(v);
 #else
-    data_d[d_offset + i00] = data_a[a_offset + i00];
+    data_d[d_offset + i00] = D_TYPE(v);
 #endif
 }
index 775b48cd05e1af31c004d73c5381f9d99175f8ac..bb429dd594588a62c36c6b0d9a758aedd85d86aa 100644 (file)
@@ -6,7 +6,7 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
+#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
 #define K_PER_ITER 8
 #else
 #define K_PER_ITER 2
index 23ce8ceec332bc8f401a664233a844a8d4ab86ef..529ac4d44feccda70bb3564d894f6458a80af9be 100644 (file)
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif
 
+#if defined(DATA_A_BF16) && defined(COOPMAT)
+#extension GL_EXT_bfloat16 : enable
+#endif
+
 #ifdef COOPMAT
 #extension GL_KHR_cooperative_matrix : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #define LOAD_VEC_B 1
 #endif
 
+#if !defined(TO_FLOAT_TYPE)
+#define TO_FLOAT_TYPE FLOAT_TYPE
+#endif
+
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -202,8 +210,8 @@ void main() {
 #endif
 
 #ifdef COOPMAT
-    coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
-    coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
+    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
+    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
     coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
 
     [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@@ -248,6 +256,21 @@ void main() {
                 buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
             }
 #endif
+#elif defined(DATA_A_BF16)
+#if LOAD_VEC_A == 4
+            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+            buf_a[buf_idx    ] = TO_FLOAT_TYPE(data_a[idx].x);
+            buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
+            buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
+            buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
+#else
+            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
+                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+            } else {
+                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
+            }
+#endif
 #elif defined(DATA_A_Q4_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
@@ -695,13 +718,13 @@ void main() {
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
             const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
+            buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
+            buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
+            buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
+            buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
 #elif !MUL_MAT_ID
             if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
             } else {
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
             }
@@ -709,7 +732,7 @@ void main() {
             const uint row_i = ic * BN + loadc_b + l;
             if (row_i < _ne1) {
                 const u16vec2 row_idx = row_ids[row_i];
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
             } else {
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
             }
index 06b7ab09ea51ae4f62db0bb7ad16e36f43c55f96..344b466101beb066f2d1e3b0240805c3ebde9ebb 100644 (file)
@@ -14,6 +14,9 @@
 #extension GL_EXT_buffer_reference : enable
 #extension GL_KHR_shader_subgroup_ballot : enable
 #extension GL_KHR_shader_subgroup_vote : enable
+#ifdef DATA_A_BF16
+#extension GL_EXT_bfloat16 : enable
+#endif
 
 #include "types.comp"
 
@@ -80,6 +83,12 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #define store_scales(a)
 #endif
 
+#if defined(DATA_A_BF16)
+#define MAT_TYPE bfloat16_t
+#else
+#define MAT_TYPE FLOAT_TYPE
+#endif
+
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
@@ -271,8 +280,8 @@ void main() {
 
                 // Manually partial unroll
                 [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
-                    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                     coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -286,8 +295,8 @@ void main() {
                 store_scales(tid);
             }
             while (block_k < end_k) {
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
@@ -310,8 +319,8 @@ void main() {
 
                 // Manually partial unroll
                 [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
-                    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                     coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -325,8 +334,8 @@ void main() {
                 store_scales(tid);
             }
             while (block_k < end_k) {
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
@@ -350,8 +359,8 @@ void main() {
 
                 // Manually partial unroll
                 [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
-                    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                    coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
                     coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                     coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -365,8 +374,8 @@ void main() {
                 store_scales(tid);
             }
             while (block_k < end_k) {
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-                coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
@@ -405,8 +414,8 @@ void main() {
                 fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
             }
 
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp
new file mode 100644 (file)
index 0000000..fd0ba40
--- /dev/null
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_EXT_bfloat16 : require
+
+void main()
+{
+}
index f5b29bfb13a6690429a60ae26cf022c2e8d3b4fa..3bde717832b45115d4e538a5e2d44c086e1d0e2a 100644 (file)
 #endif
 #endif
 
+#if defined(DATA_A_BF16)
+#define QUANT_K 1
+#define QUANT_R 1
+
+#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
+#define A_TYPE uint16_t
+#elif LOAD_VEC_A == 4
+#define A_TYPE u16vec4
+#elif LOAD_VEC_A == 8
+#error unsupported
+#endif
+#endif
+
 #define QUANT_K_Q4_0 32
 #define QUANT_R_Q4_0 2
 
@@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
 }
 #endif
 
+// returns the bfloat value in the low 16b.
+// See ggml_compute_fp32_to_bf16
+uint32_t fp32_to_bf16(float f)
+{
+    uint32_t u = floatBitsToUint(f);
+    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
+    return u;
+}
+
+float bf16_to_fp32(uint32_t u)
+{
+    return uintBitsToFloat(u << 16);
+}
+
 #endif // !defined(GGML_TYPES_COMP)
index cf74625cc56d5ab7fac8de0e4ade888160255fc0..3b28578545ed5cb80322d7424649568952d3cda3 100644 (file)
@@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
     "iq3_xxs",
     "iq3_s",
     "iq4_xs",
-    "iq4_nl"
+    "iq4_nl",
+    "bf16",
 };
 
 namespace {
@@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
     std::map<std::string, std::string> base_dict = {
-        {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
         {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
     };
     std::string shader_name = "matmul";
@@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
     const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
+    auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
+        if (t == "bf16") {
+            // scalar path promotes to float
+            if (!coopmat && !coopmat2) {
+                return "float";
+            }
+            return "bfloat16_t";
+        }
+        if (coopmat2 || fp16) {
+            return "float16_t";
+        }
+        return "float";
+    };
+
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+
+    // bf16
+    {
+        std::string load_vec_a_unaligned = "1";
+        // For aligned matmul loads
+        std::string load_vec_a = coopmat2 ? "1" : "4";
+
+        // scalar path promotes to float
+        std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
 
-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+        // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
+#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (!(coopmat || coopmat2))
+#endif
+        {
+            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"},   {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                      {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"},                          {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}),                   fp16, coopmat, coopmat2, f16acc);
+        }
+    }
 
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
@@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
             load_vec_quant = "4";
 
+        if (tname == "bf16") {
+            continue;
+        }
+
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant;
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
         // For aligned matmul loads
-        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant;
+        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32",         source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float"},            {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float"},            {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16",         source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16",         source_name,  merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
-            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
     }
@@ -393,6 +430,7 @@ void process_shaders() {
             if (tname == "f32") {
                 continue;
             }
+            if (tname == "bf16") continue;
 
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
@@ -417,12 +455,12 @@ void process_shaders() {
         string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
 
         // Dequant shaders
-        if (tname != "f16") {
+        if (tname != "f16" && tname != "bf16") {
             string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
         }
 
         if (!string_ends_with(tname, "_k")) {
-            shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
+            shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
 
             if (tname == "f16") {
                 string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
@@ -447,9 +485,11 @@ void process_shaders() {
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
     string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 
     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
         string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});