]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Vulkan: VK_KHR_cooperative_matrix support to speed up prompt processing (llama/10597)
author0cc4m <redacted>
Sat, 7 Dec 2024 09:24:15 +0000 (10:24 +0100)
committerGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 10:52:16 +0000 (12:52 +0200)
* Vulkan: Implement VK_KHR_cooperative_matrix support in the matrix matrix multiplication shader

* Improve performance with better q4_k and q5_k dequant and store unrolling

* Add Vulkan MUL_MAT and MUL_MAT_ID accumulator precision selection

* Rework mulmat shader selection and compilation logic, avoid compiling shaders that won't get used by device

* Vulkan: Implement accumulator switch for specific mul mat mat shaders

* Vulkan: Unroll more loops for more mul mat mat performance

* Vulkan: Add VK_AMD_shader_core_properties2 support to read Compute Unit count for split_k logic

* Disable coopmat support on AMD proprietary driver

* Remove redundant checks

* Add environment variable GGML_VK_DISABLE_COOPMAT to disable VK_KHR_cooperative_matrix support

* Fix rebase typo

* Fix coopmat2 MUL_MAT_ID pipeline selection

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index c7ac0e8f7bd70aa38b0e9035d19143d8f61cdd02..4b2e449a58b409f1b37f9233ded262585179c1dc 100644 (file)
@@ -1,7 +1,8 @@
 #include "ggml-vulkan.h"
 #include <vulkan/vulkan_core.h>
-#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
+#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
 #include <chrono>
+#include "ggml-cpu.h"
 #endif
 
 #include <vulkan/vulkan.hpp>
@@ -169,8 +170,22 @@ struct vk_device_struct {
     bool uma;
     bool coopmat2;
 
+    bool coopmat_support;
+    bool coopmat_acc_f32_support;
+    bool coopmat_acc_f16_support;
+    uint32_t coopmat_m;
+    uint32_t coopmat_n;
+    uint32_t coopmat_k;
+
     size_t idx;
 
+    bool mul_mat_l;
+    bool mul_mat_m;
+    bool mul_mat_s;
+    bool mul_mat_id_l;
+    bool mul_mat_id_m;
+    bool mul_mat_id_s;
+
     vk_matmul_pipeline pipeline_matmul_f32;
     vk_matmul_pipeline pipeline_matmul_f32_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16;
@@ -181,10 +196,10 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32;
-    vk_matmul_pipeline pipeline_matmul_id_f16;
-    vk_matmul_pipeline pipeline_matmul_id_f16_f32;
+    vk_matmul_pipeline2 pipeline_matmul_id_f16;
+    vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
 
-    vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
@@ -1325,6 +1340,18 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
     return {64, 64};
 };
 
+static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
+    // Needs to be kept up to date on shader changes
+    const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
+    const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+    const uint32_t warps = warptile[0] / device->subgroup_size;
+
+    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
+    const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
+    const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+
+    return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
+}
 
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@@ -1382,12 +1409,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_align =  64;
         s_align =  32;
     } else {
-        l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
-        m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-        s_warptile = { subgroup_size_16,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
-        l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
-        m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-        s_warptile_mmq = { subgroup_size_16,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+        // Matrix cores require different warp group sizes
+        const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
+        const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
+        const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
+        const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
+        const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
+        const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
+        const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
+        const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
+        const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
+
+        l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
+        m_warptile = { 128,  64,  64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
+
+        l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
+        m_warptile_mmq = { 128,  64,  64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
+        s_warptile_mmq = { subgroup_size_16, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
+
         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
         m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
         s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
@@ -1428,25 +1468,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
             // assert mul_mat_mat_id shaders will fit.
             GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
         }
+        // Disable medium and large matrix multiplication if not enough shared memory is available
+        // Check mmq warptiles as the largest configuration
+        // Throw an error if not enough for any matrix multiplication is available
+        if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
+            std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
+            throw std::runtime_error("Shared memory size too small for matrix multiplication.");
+        } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
+            device->mul_mat_m = false;
+            device->mul_mat_l = false;
+        } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
+            device->mul_mat_l = false;
+        }
+
+        // Disable mul_mat_id if not enough shared memory is available
+        if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
+            device->mul_mat_id_s = false;
+            device->mul_mat_id_m = false;
+            device->mul_mat_id_l = false;
+        } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
+            device->mul_mat_id_m = false;
+            device->mul_mat_id_l = false;
+        } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
+            device->mul_mat_id_l = false;
+        }
     }
 
     device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
     device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
 
     device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
-    device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
 
     std::vector<std::future<void>> compiles;
     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false) {
@@ -1543,119 +1594,191 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
         CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
-#endif
-    if (device->fp16) {
+#endif  // defined(VK_NV_cooperative_matrix2)
+    if (device->coopmat_support) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
-
-        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _m) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _s) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _l) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \
+        if (device->mul_mat ## ID ## _m) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \
+        if (device->mul_mat ## ID ## _s) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
+
+        // Create 2 variants, {f16,f32} accumulator
+#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+
+        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
-        if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
-            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-            CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-            CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+        if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
+            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        }
+#undef CREATE_MM
+    } else if (device->fp16) {
+        // Create 6 variants, {s,m,l}x{unaligned,aligned}
+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _m) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _s) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _l) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \
+        if (device->mul_mat ## ID ## _m) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \
+        if (device->mul_mat ## ID ## _s) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
+
+        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+        // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
+        if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
+            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
 #undef CREATE_MM
     } else {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
-#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
-
-        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
-        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
+#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _m) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _s) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _l) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align);   \
+        if (device->mul_mat ## ID ## _m) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align);   \
+        if (device->mul_mat ## ID ## _s) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
+
+        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
-        if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
-            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-            CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-            CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
-
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
-            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
+        if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
+            CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+            CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
+#undef CREATE_MM2
 #undef CREATE_MM
     }
 
@@ -1851,8 +1974,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool fp16_compute = false;
         bool maintenance4_support = false;
         bool sm_builtins = false;
+        bool amd_shader_core_properties2 = false;
         bool pipeline_robustness = false;
         bool coopmat2_support = false;
+        device->coopmat_support = false;
 
         // Check if maintenance4 is supported
         for (const auto& properties : ext_props) {
@@ -1864,10 +1989,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 fp16_compute = true;
             } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
                 sm_builtins = true;
+            } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
+                amd_shader_core_properties2 = true;
             } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
                 pipeline_robustness = true;
+            } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
+                       !getenv("GGML_VK_DISABLE_COOPMAT")) {
+                device->coopmat_support = true;
+                device->coopmat_m = 0;
+                device->coopmat_n = 0;
+                device->coopmat_k = 0;
             } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
-                       !getenv("GGML_VULKAN_DISABLE_COOPMAT2")) {
+                       !getenv("GGML_VK_DISABLE_COOPMAT2")) {
                 coopmat2_support = true;
             }
         }
@@ -1876,11 +2009,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceMaintenance3Properties props3;
         vk::PhysicalDeviceMaintenance4Properties props4;
         vk::PhysicalDeviceSubgroupProperties subgroup_props;
+        vk::PhysicalDeviceDriverProperties driver_props;
         vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
+        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
+        subgroup_props.pNext = &driver_props;
 
-        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
+        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
 
         if (maintenance4_support) {
             last_struct->pNext = (VkBaseOutStructure *)&props4;
@@ -1890,6 +2026,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
             last_struct->pNext = (VkBaseOutStructure *)&sm_props;
             last_struct = (VkBaseOutStructure *)&sm_props;
         }
+        if (amd_shader_core_properties2) {
+            last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
+            last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
+        }
 
 #if defined(VK_NV_cooperative_matrix2)
         vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
@@ -1905,7 +2045,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
 
         if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
-            device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
+            device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
         } else if (maintenance4_support) {
             device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
         } else {
@@ -1917,15 +2057,22 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
         if (sm_builtins) {
             device->shader_core_count = sm_props.shaderSMCount;
+        } else if (amd_shader_core_properties2) {
+            device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
         } else {
             device->shader_core_count = 0;
         }
 
-        const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
-        const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
+        const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
+        if (device->vendor_id == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
+            // Intel drivers don't support coopmat properly yet
+            // Only RADV supports coopmat properly on AMD
+            device->coopmat_support = false;
+        }
+
         std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
         // Try to find a non-graphics compute queue and transfer-focused queues
@@ -1976,6 +2123,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_EXT_pipeline_robustness");
         }
 
+        VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
+        coopmat_features.pNext = nullptr;
+        coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
+        coopmat_features.cooperativeMatrix = VK_FALSE;
+
+        if (device->coopmat_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
+            last_struct = (VkBaseOutStructure *)&coopmat_features;
+        }
+
 #if defined(VK_NV_cooperative_matrix2)
         VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
         coopmat2_features.pNext = nullptr;
@@ -1993,6 +2150,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
 
+        device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
+
         if (coopmat2_support) {
 #if defined(VK_NV_cooperative_matrix2)
             if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
@@ -2083,6 +2242,74 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (device->fp16) {
             device_extensions.push_back("VK_KHR_shader_float16_int8");
         }
+
+        if (device->coopmat_support) {
+            // Query supported shapes
+            std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
+
+            PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
+                (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
+
+            uint32_t cm_props_num;
+
+            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
+
+            cm_props.resize(cm_props_num);
+
+            for (auto& prop : cm_props) {
+                prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
+            }
+
+            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
+
+            VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
+
+            for (auto& prop : cm_props) {
+                VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
+
+                if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
+                    (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
+                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
+                ) {
+                    if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
+                        (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
+                        // coopmat sizes not set yet
+                        if (device->coopmat_m == 0) {
+                            device->coopmat_acc_f32_support = true;
+                            device->coopmat_m = prop.MSize;
+                            device->coopmat_n = prop.NSize;
+                            device->coopmat_k = prop.KSize;
+                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
+                            // Only enable if shape is identical
+                            device->coopmat_acc_f32_support = true;
+                        }
+                    } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
+                               (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
+                        // coopmat sizes not set yet
+                        if (device->coopmat_m == 0) {
+                            device->coopmat_acc_f16_support = true;
+                            device->coopmat_m = prop.MSize;
+                            device->coopmat_n = prop.NSize;
+                            device->coopmat_k = prop.KSize;
+                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
+                            // Only enable if shape is identical
+                            device->coopmat_acc_f16_support = true;
+                        }
+                    }
+                }
+            }
+
+            if (device->coopmat_m == 0) {
+                // No suitable matmul mode found
+                GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
+                device->coopmat_support = false;
+            }
+        }
+
+        if (device->coopmat_support) {
+            device_extensions.push_back("VK_KHR_cooperative_matrix");
+        }
+
         device->name = GGML_VK_NAME + std::to_string(idx);
 
         device_create_info = {
@@ -2098,6 +2325,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
         ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
 
         // Shaders
+        // Disable matmul tile sizes early if performance low or not supported
+        switch (device->vendor_id) {
+#ifndef GGML_VULKAN_RUN_TESTS
+        case VK_VENDOR_ID_AMD:
+        case VK_VENDOR_ID_INTEL:
+            device->mul_mat_l = false;
+            device->mul_mat_m = true;
+            device->mul_mat_s = true;
+            device->mul_mat_id_l = false;
+            device->mul_mat_id_m = true;
+            device->mul_mat_id_s = true;
+            break;
+        case VK_VENDOR_ID_APPLE:
+            device->mul_mat_l = false;
+            device->mul_mat_m = true;
+            device->mul_mat_s = false;
+            device->mul_mat_id_l = false;
+            device->mul_mat_id_m = true;
+            device->mul_mat_id_s = false;
+            break;
+#endif
+        default:
+            device->mul_mat_l = true;
+            device->mul_mat_m = true;
+            device->mul_mat_s = true;
+            device->mul_mat_id_l = true;
+            device->mul_mat_id_m = true;
+            device->mul_mat_id_s = true;
+            break;
+        }
+
         ggml_vk_load_shaders(device);
 
         if (!device->single_queue) {
@@ -2155,15 +2413,24 @@ static void ggml_vk_print_gpu_info(size_t idx) {
 
     bool fp16_storage = false;
     bool fp16_compute = false;
+    bool coopmat_support = false;
 
     for (auto properties : ext_props) {
         if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
             fp16_storage = true;
         } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
             fp16_compute = true;
+        } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
+            coopmat_support = true;
         }
     }
 
+    if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && driver_props.driverID == vk::DriverId::eAmdProprietary)) {
+        // Intel drivers don't support coopmat properly yet
+        // Only RADV supports coopmat properly on AMD
+        coopmat_support = false;
+    }
+
     const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
     bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
@@ -2186,13 +2453,28 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
     vk11_features.pNext = &vk12_features;
 
+    // Pointer to the last chain element
+    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
+
+    VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
+    coopmat_features.pNext = nullptr;
+    coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
+    coopmat_features.cooperativeMatrix = VK_FALSE;
+
+    if (coopmat_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
+        last_struct = (VkBaseOutStructure *)&coopmat_features;
+    }
+
     vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
 
     fp16 = fp16 && vk12_features.shaderFloat16;
 
+    coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
+
     std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu\n",
-              idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size);
+    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %d\n",
+              idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, coopmat_support);
 
     if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
         GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -2428,7 +2710,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
         return ctx->device->pipeline_matmul_f32_f16;
     }
-    if (prec == GGML_PREC_DEFAULT && ctx->device->coopmat2) {
+    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
         if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
             return ctx->device->pipeline_matmul_f16_f32.f16acc;
         }
@@ -2469,7 +2751,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         assert(src1_type == GGML_TYPE_F16);
         return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
     }
-    return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
+    return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
 }
 
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -2498,16 +2780,25 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
     return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
 }
 
-static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
+static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
     VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
     if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
         return ctx->device->pipeline_matmul_id_f32;
     }
-    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
-        return ctx->device->pipeline_matmul_id_f16_f32;
-    }
-    if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
-        return ctx->device->pipeline_matmul_id_f16;
+    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16) {
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
+        }
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_matmul_id_f16.f16acc;
+        }
+    } else {
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
+        }
+        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+            return ctx->device->pipeline_matmul_id_f16.f32acc;
+        }
     }
 
     GGML_ASSERT(src1_type == GGML_TYPE_F32);
@@ -2529,7 +2820,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
             return nullptr;
     }
 
-    return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
+    return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
 }
 
 static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -3119,62 +3410,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
     return split_k;
 }
 
-static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
-    if (m <= 32 || n <= 32) {
-        return aligned ? mmp->a_s : mmp->s;
-    }
-    return aligned ? mmp->a_m : mmp->m;
-
-    GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
-    return aligned ? mmp->a_m : mmp->m;
-
-    GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
-    return aligned ? mmp->a_s : mmp->s;
-
-    GGML_UNUSED(ctx);
-}
-
-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type type_a) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
-    switch (ctx->device->vendor_id) {
-    case VK_VENDOR_ID_AMD:
-        return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
-    case VK_VENDOR_ID_APPLE:
-        return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
-    case VK_VENDOR_ID_INTEL:
-        return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
-    default:
-        break;
-    }
 
     if (ctx->device->coopmat2) {
-        if ((m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) {
+        if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) {
+        if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
     }
 
-    if (m <= 32 || n <= 32) {
+    if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
         return aligned ? mmp->a_s : mmp->s;
     }
-    if (m <= 64 || n <= 64) {
+    if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
         return aligned ? mmp->a_m : mmp->m;
     }
     return aligned ? mmp->a_l : mmp->l;
 }
 
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type type_a) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
+    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, type_a)->align;
 }
 
 static void ggml_vk_matmul(
@@ -3201,6 +3461,33 @@ static void ggml_vk_matmul(
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }
 
+static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
+
+    if (ctx->device->coopmat2) {
+        if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
+            return aligned ? mmp->a_l : mmp->l;
+        }
+        if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
+            return aligned ? mmp->a_m : mmp->m;
+        }
+        return aligned ? mmp->a_s : mmp->s;
+    }
+
+    if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
+        return aligned ? mmp->a_s : mmp->s;
+    }
+    if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
+        return aligned ? mmp->a_m : mmp->m;
+    }
+    return aligned ? mmp->a_l : mmp->l;
+}
+
+static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
+    return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
+}
+
 static void ggml_vk_matmul_id(
         ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
@@ -3350,10 +3637,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
+    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, src0->type);
 
     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
 
@@ -3904,7 +4191,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
+    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
     const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
@@ -3920,10 +4207,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t y_ne = ne11 * ne10;
     const uint64_t d_ne = ne21 * ne20;
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
+    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -5504,19 +5791,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
     for (size_t i = 0; i < x_ne; i++) {
         if (std::is_same<float, X_TYPE>()) {
             x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            // x[i] = 1.0f;
+            // x[i] = i + 1;
+            // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
         } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
             x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            // x[i] = ggml_fp32_to_fp16(1.0f);
+            // x[i] = ggml_fp32_to_fp16(i + 1);
+            // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
         } else {
             GGML_ABORT("fatal error");
         }
     }
     for (size_t i = 0; i < y_ne; i++) {
         if (std::is_same<float, Y_TYPE>()) {
-            // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
-            y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+            // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+            // y[i] = i + 1;
         } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
-            // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
-            y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
+            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
+            // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
+            // y[i] = ggml_fp32_to_fp16(i + 1);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -5600,7 +5895,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         double err = std::fabs(d[i] - d_chk[i]);
         avg_err += err;
 
-        if (err > 0.05f && first_err_n == -1) {
+        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
             first_err_b = i / (m * n);
             first_err_n = (i % (m * n)) / m;
             first_err_m = (i % (m * n)) % m;
@@ -5613,12 +5908,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
 
     std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
 
-    if (avg_err > 0.1) {
+    if (avg_err > 0.1 || std::isnan(avg_err)) {
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
         std::cerr << "Actual result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
-        std::cerr << std::endl;
-        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
@@ -5801,13 +6094,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
     vk_pipeline p;
     std::string shname;
     if (shader_size == 0) {
-        p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
+        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
     } else if (shader_size == 1) {
-        p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
+        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
     } else if (shader_size == 2) {
-        p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
+        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
     } else {
         GGML_ASSERT(0);
@@ -5817,13 +6110,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     if (k != kpad) {
         if (shader_size == 0) {
-            p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
+            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
             shname = std::string(ggml_type_name(quant)) + "_S";
         } else if (shader_size == 1) {
-            p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
+            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
             shname = std::string(ggml_type_name(quant)) + "_M";
         } else if (shader_size == 2) {
-            p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
+            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
             shname = std::string(ggml_type_name(quant)) + "_L";
         } else {
             GGML_ASSERT(0);
@@ -5982,105 +6275,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
 static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
 #if defined(GGML_VULKAN_RUN_TESTS)
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
-    ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
-
-    ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
-
-    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
-    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
-    ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
-    // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
-    // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
-    // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
-    // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
-
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
-    ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
-
-    std::cerr << std::endl;
-
     const std::vector<size_t> vals {
+        512, 512, 128,
+        128, 512, 512,
+        4096, 512, 4096,
+        11008, 512, 4096,
+        4096, 512, 11008,
+        32000, 512, 4096,
         8, 8, 8,
         100, 46, 576,
         623, 111, 128,
@@ -6093,15 +6294,6 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         49, 49, 128,
         128, 49, 49,
         4096, 49, 4096,
-        11008, 49, 4096,
-        4096, 49, 11008,
-        32000, 49, 4096,
-        512, 512, 128,
-        128, 512, 512,
-        4096, 512, 4096,
-        11008, 512, 4096,
-        4096, 512, 11008,
-        32000, 512, 4096,
     };
     const size_t num_it = 100;
 
@@ -6109,10 +6301,45 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
-        // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
-        // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
-        // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
-        std::cerr << std::endl;
+        std::cerr << '\n';
+        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
+        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
+        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
+        std::cerr << '\n';
+        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
+        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
+        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
+        std::cerr << '\n' << std::endl;
+
+        if (vals[i + 2] % 32 == 0) {
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
+            std::cerr << '\n';
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
+            std::cerr << '\n';
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
+            std::cerr << '\n' << std::endl;
+        }
+
+        if (vals[i + 2] % 256 == 0) {
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
+            std::cerr << '\n';
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
+            std::cerr << '\n';
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
+            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
+            std::cerr << '\n' << std::endl;
+        }
     }
 
     GGML_ABORT("fatal error");
@@ -7200,8 +7427,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_MUL_MAT_ID:
             {
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                if (op->op == GGML_OP_MUL_MAT_ID &&
-                    ggml_vk_get_device(ctx->device)->properties.limits.maxComputeSharedMemorySize < 32768) {
+                const vk_device& device = ggml_vk_get_device(ctx->device);
+                if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
                     // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
                     return false;
                 }
index 2ff5c430519bce79be5ee672649252fbb29c04d2..48122cbef906ebb92e59322df8c602594541df3d 100644 (file)
@@ -7,6 +7,12 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif
 
+#ifdef COOPMAT
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
+
 #ifdef MUL_MAT_ID
 #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
 #endif
@@ -57,6 +63,7 @@ layout (push_constant) uniform parameter
 #endif
 } p;
 
+layout (constant_id = 0) const uint BLOCK_SIZE = 64;
 layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
 layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
@@ -65,13 +72,26 @@ layout (constant_id = 5) const uint WN = 32;
 layout (constant_id = 6) const uint WMITER = 2;
 layout (constant_id = 7) const uint TM = 4;
 layout (constant_id = 8) const uint TN = 2;
-layout (constant_id = 9) const uint WARP = 32;
+layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
+layout (constant_id = 10) const uint WARP = 32;
+
+#ifdef COOPMAT
+#define SHMEM_STRIDE (BK + 8)
+#else
+#define SHMEM_STRIDE (BK + 1)
+#endif
 
-shared FLOAT_TYPE buf_a[BM * (BK+1)];
-shared FLOAT_TYPE buf_b[BN * (BK+1)];
+shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
+shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
 
 #ifdef MUL_MAT_ID
 shared u16vec2 row_ids[3072];
+#endif // MUL_MAT_ID
+
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
+#ifdef COOPMAT
+shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif
 
 void main() {
@@ -98,17 +118,32 @@ void main() {
     const uint ik = gl_WorkGroupID.x / blocks_m;
     const uint ic = gl_WorkGroupID.y;
 
-    const uint warp_i = gl_LocalInvocationID.x / WARP;
-    const uint warp_r = warp_i % (BM / WM);
-    const uint warp_c = warp_i / (BM / WM);
-
     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
     const uint WSUBN = WN / WNITER;
 
+#ifdef COOPMAT
+    const uint warp_i = gl_SubgroupID;
+
+    const uint tiw = gl_SubgroupInvocationID;
+
+    const uint cms_per_row = WM / TM;
+    const uint cms_per_col = WN / TN;
+
+    const uint storestride = WARP / TM;
+    const uint store_r = tiw % TM;
+    const uint store_c = tiw / TM;
+#else
+    const uint warp_i = gl_LocalInvocationID.x / WARP;
+
     const uint tiw = gl_LocalInvocationID.x % WARP;
+
     const uint tiwr = tiw % (WSUBM / TM);
     const uint tiwc = tiw / (WSUBM / TM);
+#endif
+
+    const uint warp_r = warp_i % (BM / WM);
+    const uint warp_c = warp_i / (BM / WM);
 
     const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
     const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
@@ -156,21 +191,31 @@ void main() {
     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
 #endif
 
-    float sums[WMITER * TM * WNITER * TN];
+#ifdef COOPMAT
+    coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
+    coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
+    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
+
+    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
+        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
+    }
+#else
+    ACC_TYPE sums[WMITER * TM * WNITER * TN];
     FLOAT_TYPE cache_a[WMITER * TM];
     FLOAT_TYPE cache_b[WNITER * TN];
 
     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
+        sums[i] = ACC_TYPE(0.0f);
     }
+#endif
 
-    [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
+    for (uint block = start_k; block < end_k; block += BK) {
         [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
 
 #if defined(DATA_A_F32) || defined(DATA_A_F16)
 #if LOAD_VEC_A == 8
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
             buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
             buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
             buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
@@ -181,21 +226,21 @@ void main() {
             buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
 #elif LOAD_VEC_A == 4
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
             buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
             buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
             buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
             buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
 #else
             if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-                buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
+                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
             } else {
-                buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
+                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
             }
 #endif
 #elif defined(DATA_A_Q4_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
 
             const uint ib = idx / 16;
             const uint iqs = idx & 0xF;
@@ -208,7 +253,7 @@ void main() {
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q4_1)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
 
             const uint ib = idx / 16;
             const uint iqs = idx & 0xF;
@@ -222,7 +267,7 @@ void main() {
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q5_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
 
             const uint ib = idx / 16;
             const uint iqs = idx & 0xF;
@@ -237,7 +282,7 @@ void main() {
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q5_1)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
 
             const uint ib = idx / 16;
             const uint iqs = idx & 0xF;
@@ -253,7 +298,7 @@ void main() {
             buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q8_0)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
             const uint ib = idx / 16;
             const uint iqs = (idx & 0xF) * 2;
@@ -265,7 +310,7 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q2_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
             const uint ib = idx / 128;                         // 2 values per idx
             const uint iqs = idx % 128;                        // 0..127
@@ -284,7 +329,7 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q3_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
             const uint ib = idx / 128;                   // 2 values per idx
             const uint iqs = idx % 128;                  // 0..127
@@ -308,7 +353,7 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
 #elif defined(DATA_A_Q4_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
             const uint ib = idx / 128;                 // 2 values per idx
             const uint iqs = idx % 128;                // 0..127
@@ -320,15 +365,20 @@ void main() {
 
             const vec2 loadd = vec2(data_a[ib].d);
 
-            uint8_t sc;
-            uint8_t mbyte;
-            if (is < 4) {
-                sc    = uint8_t(data_a[ib].scales[is    ] & 63);
-                mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
-            } else {
-                sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
-                mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
-            }
+            const uint scidx0 = (is < 4) ? is : (is + 4);
+            const uint scidx1 = (is < 4) ? is : (is - 4);
+            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+            const uint scidxshift1 = (is < 4) ? 0 : 2;
+            const uint mbidx0 = is + 4;
+            const uint mbidx1 = (is < 4) ? is + 4 : is;
+            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+            const uint mbidxshift0 = (is < 4) ? 0 : 4;
+            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+            const uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+            const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+            const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
@@ -336,7 +386,7 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
 #elif defined(DATA_A_Q5_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
             const uint ib = idx / 128;                 // 2 values per idx
             const uint iqs = idx % 128;                // 0..127
@@ -351,15 +401,20 @@ void main() {
 
             const vec2 loadd = vec2(data_a[ib].d);
 
-            uint8_t sc;
-            uint8_t mbyte;
-            if (is < 4) {
-                sc    = uint8_t(data_a[ib].scales[is    ] & 63);
-                mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
-            } else {
-                sc    = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
-                mbyte = uint8_t((data_a[ib].scales[is + 4] >>  4) | ((data_a[ib].scales[is    ] >> 6) << 4));
-            }
+            const uint scidx0 = (is < 4) ? is : (is + 4);
+            const uint scidx1 = (is < 4) ? is : (is - 4);
+            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+            const uint scidxshift1 = (is < 4) ? 0 : 2;
+            const uint mbidx0 = is + 4;
+            const uint mbidx1 = (is < 4) ? is + 4 : is;
+            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+            const uint mbidxshift0 = (is < 4) ? 0 : 4;
+            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+            const uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+            const uint8_t sc    = uint8_t((data_a[ib].scales[scidx0] & 0xF)                         | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+            const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
             const float d = loadd.x * sc;
             const float m = -loadd.y * mbyte;
 
@@ -367,7 +422,7 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
 #elif defined(DATA_A_Q6_K)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
             const uint ib = idx / 128;                  // 2 values per idx
             const uint iqs = idx % 128;                 // 0..127
@@ -386,7 +441,7 @@ void main() {
             buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
 #elif defined(DATA_A_IQ4_NL)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
 
             const uint ib = idx / 16;
             const uint iqs = idx & 0xF;
@@ -407,7 +462,7 @@ void main() {
 #else
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
-            const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+            const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
             buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
             buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
             buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
@@ -423,24 +478,24 @@ void main() {
 #else
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
-            const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
+            const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
             buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
             buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
             buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
             buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
 #elif !MUL_MAT_ID
             if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
             } else {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
             }
 #else
             const uint row_i = ic * BN + loadc_b + l;
             if (row_i < _ne1) {
                 const u16vec2 row_idx = row_ids[row_i];
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
             } else {
-                buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
+                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
             }
 #endif
         }
@@ -450,16 +505,30 @@ void main() {
         pos_a += BK / LOAD_VEC_A;
         pos_b += BK / LOAD_VEC_B;
 
-        for (uint i = 0; i < BK; i++) {
+#ifdef COOPMAT
+        [[unroll]] for (uint i = 0; i < BK; i += TK) {
+            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+                // Load from shared into cache
+                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
+
+                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
+
+                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
+                }
+            }
+        }
+#else
+        [[unroll]] for (uint i = 0; i < BK; i++) {
             // Load from shared into cache
             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
+                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                 }
             }
             [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                 [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
+                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
                 }
             }
 
@@ -468,12 +537,13 @@ void main() {
                     [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                         [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                             const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
+                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
                         }
                     }
                 }
             }
         }
+#endif
 
         barrier();
     }
@@ -485,6 +555,54 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
+#ifdef COOPMAT
+#ifdef MUL_MAT_ID
+    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+            [[unroll]] for (uint col = 0; col < BN; col += storestride) {
+                const uint row_i = dc + cm_col * TN + col + store_c;
+                if (row_i >= _ne1) break;
+
+                const u16vec2 row_idx = row_ids[row_i];
+
+                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+            }
+        }
+    }
+#else
+    const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float
+
+    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
+
+            if (is_aligned && is_in_bounds) {
+                // Full coopMat is within bounds and stride_d is aligned with 16B
+                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
+                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+            } else if (is_in_bounds) {
+                // Full coopMat is within bounds, but stride_d is not aligned
+                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+                }
+            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
+                // Partial coopMat is within bounds
+                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
+                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+                    }
+                }
+            }
+        }
+    }
+#endif // MUL_MAT_ID
+#else
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
 
@@ -496,7 +614,7 @@ void main() {
                 if (row_i >= _ne1) break;
 
                 const u16vec2 row_idx = row_ids[row_i];
-#endif
+#endif // MUL_MAT_ID
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
 #ifdef MUL_MAT_ID
                     data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
@@ -504,9 +622,10 @@ void main() {
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                     }
-#endif
+#endif // MUL_MAT_ID
                 }
             }
         }
     }
+#endif // COOPMAT
 }
index 4716e2c80869bfdec9e06103ef01504ef65fdb32..c4967ed658eb2674d6cd8899c47d8ebd670358b4 100644 (file)
@@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
     "iq4_nl"
 };
 
+namespace {
 void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
 #ifdef _WIN32
     HANDLE stdout_read, stdout_write;
@@ -198,8 +199,8 @@ static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
 static std::condition_variable compile_count_cond;
 
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
-    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
+void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+    std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
     std::string out_fname = join_paths(output_dir, name + ".spv");
     std::string in_path = join_paths(input_dir, in_fname);
 
@@ -258,7 +259,7 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s
 }
 
 static std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat2 = false, bool f16acc = false) {
+void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
     {
         // wait until fewer than N compiles are in progress.
         // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
@@ -269,10 +270,10 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
         }
         compile_count++;
     }
-    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat2, f16acc));
+    compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
 }
 
-void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
+void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
     std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
@@ -291,14 +292,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
 
     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
 
+    if (coopmat) {
+        base_dict["COOPMAT"] = "1";
+    }
+
+    base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+
     std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
 
-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
 
     for (const auto& tname : type_names) {
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
@@ -307,12 +314,12 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat2, bool f16acc) {
         // For aligned matmul loads
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
 
-        string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
-        string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+        string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
     }
 }
@@ -322,25 +329,24 @@ void process_shaders() {
     std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
 
     // matmul
-    for (const auto& fp16 : {false, true}) {
-        for (const auto& matmul_id : {false, true}) {
-            for (const auto& coopmat2 : {false, true}) {
-                for (const auto& f16acc : {false, true}) {
-#if !defined(VK_NV_cooperative_matrix2)
-                    if (coopmat2) {
-                        continue;
-                    }
+    for (const auto& matmul_id : {false, true}) {
+        // No coopmats
+        // fp32
+        matmul_shaders(false, matmul_id, false, false, false);
+
+        // fp16, fp32acc and fp16acc
+        matmul_shaders(true, matmul_id, false, false, false);
+        matmul_shaders(true, matmul_id, false, false, true);
+
+        // Coopmat, fp32acc and fp16acc
+        matmul_shaders(true, matmul_id, true, false, false);
+        matmul_shaders(true, matmul_id, true, false, true);
+
+#if defined(VK_NV_cooperative_matrix2)
+        // Coopmat2, fp32acc and fp16acc
+        matmul_shaders(true, matmul_id, false, true, false);
+        matmul_shaders(true, matmul_id, false, true, true);
 #endif
-                    if (coopmat2 && !fp16) {
-                        continue;
-                    }
-                    if (!coopmat2 && f16acc) {
-                        continue;
-                    }
-                    matmul_shaders(fp16, matmul_id, coopmat2, f16acc);
-                }
-            }
-        }
     }
 
 #if defined(VK_NV_cooperative_matrix2)
@@ -355,11 +361,11 @@ void process_shaders() {
 
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, true, f16acc);
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
             } else {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, true, f16acc);
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
         }
     }
@@ -524,6 +530,7 @@ void write_output_files() {
     fclose(hdr);
     fclose(src);
 }
+}
 
 int main(int argc, char** argv) {
     std::map<std::string, std::string> args;