]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Vulkan: Add DP4A MMQ and Q8_1 quantization shader (llama/12135)
author0cc4m <redacted>
Mon, 31 Mar 2025 12:37:01 +0000 (14:37 +0200)
committerGeorgi Gerganov <redacted>
Wed, 2 Apr 2025 12:51:57 +0000 (15:51 +0300)
* Vulkan: Add DP4A MMQ and Q8_1 quantization shader

* Add q4_0 x q8_1 matrix matrix multiplication support

* Vulkan: Add int8 coopmat MMQ support

* Vulkan: Add q4_1, q5_0 and q5_1 quants, improve integer dot code

* Add GL_EXT_integer_dot_product check

* Remove ggml changes, fix mmq pipeline picker

* Remove ggml changes, restore Intel coopmat behaviour

* Fix glsl compile attempt when integer vec dot is not supported

* Remove redundant code, use non-saturating integer dot, enable all matmul sizes for mmq

* Remove redundant comment

* Fix integer dot check

* Fix compile issue with unsupported int dot glslc

* Update Windows build Vulkan SDK version

ggml/src/ggml-vulkan/CMakeLists.txt
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 2615ae1ab7dd2c395bf1629b9fa700ae8927529c..e3c59b75fd5a39bab4a4e048d544f27123b43548 100644 (file)
@@ -69,6 +69,20 @@ if (Vulkan_FOUND)
         add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     endif()
 
+    # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
+        message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
+    else()
+        message(STATUS "GL_EXT_integer_dot_product supported by glslc")
+        add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    endif()
+
     target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
     target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
 
index bc16567dc4524d758fb7b187a69d722d78962cc8..31330e2b29bfe1e15e54d79c0eef59a4ab3b500f 100644 (file)
@@ -234,6 +234,8 @@ struct vk_device_struct {
     bool float_controls_rte_fp16;
     bool subgroup_add;
 
+    bool integer_dot_product;
+
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
     uint32_t subgroup_max_size;
@@ -245,6 +247,12 @@ struct vk_device_struct {
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
+
+    bool coopmat_int_support;
+    uint32_t coopmat_int_m;
+    uint32_t coopmat_int_n;
+    uint32_t coopmat_int_k;
+
     bool coopmat2;
 
     size_t idx;
@@ -263,10 +271,10 @@ struct vk_device_struct {
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
-    vk_pipeline pipeline_matmul_split_k_reduce;
 
-    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;
@@ -274,6 +282,9 @@ struct vk_device_struct {
 
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
 
+    vk_pipeline pipeline_matmul_split_k_reduce;
+    vk_pipeline pipeline_quantize_q8_1;
+
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -640,6 +651,13 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_upscale_push_constants {
+    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+    float sf0; float sf1; float sf2; float sf3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -649,13 +667,6 @@ struct vk_staging_memcpy {
     size_t n;
 };
 
-struct vk_op_upscale_push_constants {
-    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
-    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
-    float sf0; float sf1; float sf2; float sf3;
-};
-
 struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
@@ -1598,6 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // mulmat
     std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
                           l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
                           l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -1662,6 +1674,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
         s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
+        const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
+        const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
+        const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
+        const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
+        const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
+        const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
+        const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
+        const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
+        const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
+
+        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
+        m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8, 32, 2,     tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2,       tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
+
         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
         m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
         s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
@@ -2000,6 +2026,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
 
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
+
         // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
         CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2031,6 +2065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product) {
+            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        }
+#endif
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2056,6 +2100,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
+#undef CREATE_MMQ
 #undef CREATE_MM
     } else {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -2073,6 +2118,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align);   \
 
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+        if (device->mul_mat ## ID ## _l[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _m[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
+        if (device->mul_mat ## ID ## _s[TYPE]) \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2099,6 +2152,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product) {
+            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+        }
+#endif
+
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2132,7 +2195,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t rm_stdq = 1;
     uint32_t rm_kq = 2;
     if (device->vendor_id == VK_VENDOR_ID_AMD) {
-        if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
+        if (device->architecture == AMD_GCN) {
             rm_stdq = 2;
             rm_kq = 4;
         }
@@ -2266,6 +2329,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl_f32",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
     for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
         if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -2452,6 +2516,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool pipeline_robustness = false;
         bool coopmat2_support = false;
         device->coopmat_support = false;
+        device->integer_dot_product = false;
 
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2477,6 +2542,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
             } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                        !getenv("GGML_VK_DISABLE_COOPMAT2")) {
                 coopmat2_support = true;
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+                       !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+                device->integer_dot_product = true;
+#endif
             }
         }
 
@@ -2490,6 +2560,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceVulkan11Properties vk11_props;
         vk::PhysicalDeviceVulkan12Properties vk12_props;
         vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
 
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
@@ -2524,6 +2595,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
         }
 #endif
 
+        if (device->integer_dot_product) {
+            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+        }
+
         device->physical_device.getProperties2(&props2);
         device->properties = props2.properties;
         device->vendor_id = device->properties.vendorID;
@@ -2570,6 +2646,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->coopmat_support = false;
         }
 
+        device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
+
         std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
         // Try to find a non-graphics compute queue and transfer-focused queues
@@ -2662,6 +2740,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_maintenance4");
         }
 
+        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+        shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+        if (device->integer_dot_product) {
+            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+            device_extensions.push_back("VK_KHR_shader_integer_dot_product");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2831,6 +2917,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
                             device->coopmat_acc_f16_support = true;
                         }
                     }
+                } else if ((vk::ComponentTypeKHR)prop.AType      == vk::ComponentTypeKHR::eSint8 &&
+                           (vk::ComponentTypeKHR)prop.BType      == vk::ComponentTypeKHR::eSint8 &&
+                           (vk::ComponentTypeKHR)prop.CType      == vk::ComponentTypeKHR::eSint32 &&
+                           (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
+                           (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
+                           device->coopmat_int_m == 0
+                ) {
+                    device->coopmat_int_support = true;
+                    device->coopmat_int_m = prop.MSize;
+                    device->coopmat_int_n = prop.NSize;
+                    device->coopmat_int_k = prop.KSize;
                 }
             }
 
@@ -2935,25 +3032,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk::PhysicalDevice physical_device = devices[dev_num];
     std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
 
-    vk::PhysicalDeviceProperties2 props2;
-    vk::PhysicalDeviceMaintenance3Properties props3;
-    vk::PhysicalDeviceSubgroupProperties subgroup_props;
-    vk::PhysicalDeviceDriverProperties driver_props;
-    props2.pNext = &props3;
-    props3.pNext = &subgroup_props;
-    subgroup_props.pNext = &driver_props;
-    physical_device.getProperties2(&props2);
-
-    vk_device_architecture arch = get_device_architecture(physical_device);
-    uint32_t default_subgroup_size = get_subgroup_size("", arch);
-    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
-
-    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
     bool fp16_storage = false;
     bool fp16_compute = false;
     bool coopmat_support = false;
     bool coopmat2_support = false;
+    bool integer_dot_product = false;
 
     for (auto properties : ext_props) {
         if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -2969,27 +3052,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT2")) {
             coopmat2_support = true;
+#endif
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+                    !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+            integer_dot_product = true;
 #endif
         }
     }
 
     const vk_device_architecture device_architecture = get_device_architecture(physical_device);
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
-        coopmat_support = false;
-    }
-
     const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
     bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
     bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-    vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
+    vk::PhysicalDeviceProperties2 props2;
+    vk::PhysicalDeviceMaintenance3Properties props3;
+    vk::PhysicalDeviceSubgroupProperties subgroup_props;
+    vk::PhysicalDeviceDriverProperties driver_props;
+    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+    props2.pNext = &props3;
+    props3.pNext = &subgroup_props;
+    subgroup_props.pNext = &driver_props;
+
+    // Pointer to the last chain element
+    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+    if (integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+    }
+
+    physical_device.getProperties2(&props2);
 
     VkPhysicalDeviceFeatures2 device_features2;
     device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
     device_features2.pNext = nullptr;
-    device_features2.features = (VkPhysicalDeviceFeatures)device_features;
 
     VkPhysicalDeviceVulkan11Features vk11_features;
     vk11_features.pNext = nullptr;
@@ -3002,7 +3102,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     vk11_features.pNext = &vk12_features;
 
     // Pointer to the last chain element
-    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
+    last_struct = (VkBaseOutStructure *)&vk12_features;
 
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@@ -3014,20 +3114,37 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
         last_struct = (VkBaseOutStructure *)&coopmat_features;
     }
+#endif
+
+    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+    if (integer_dot_product) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+    }
 
     vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
 
     fp16 = fp16 && vk12_features.shaderFloat16;
 
-    coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
-#endif
+    uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+    integer_dot_product = integer_dot_product
+                       && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+                       && shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+    coopmat_support = coopmat_support
+                   && coopmat_features.cooperativeMatrix
+                   && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
 
     std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 
     std::string device_name = props2.properties.deviceName.data();
-    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
               idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
-              props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
+              props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
 
     if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
         GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3293,6 +3410,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         }
     }
 
+    // MMQ
+    if (src1_type == GGML_TYPE_Q8_1) {
+        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+
+        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+            return nullptr;
+        }
+
+        return pipelines;
+    }
+
     if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
         return nullptr;
     }
@@ -3585,8 +3713,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
     return s;
 }
 
-
-
 static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
     const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
     const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -4016,8 +4142,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
     return split_k;
 }
 
-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
 
     if (ctx->device->coopmat2) {
         // Use large shader when the N dimension is greater than the medium shader's tile size
@@ -4042,9 +4168,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     return aligned ? mmp->a_l : mmp->l;
 }
 
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
 }
 
 static void ggml_vk_matmul(
@@ -4054,7 +4180,7 @@ static void ggml_vk_matmul(
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
         uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
         uint32_t padded_n) {
-        VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
+        VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@@ -4072,7 +4198,7 @@ static void ggml_vk_matmul(
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }
 
-static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
@@ -4214,6 +4340,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
 }
 
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+    switch(type) {
+        case GGML_TYPE_Q8_1:
+            return ctx->device->pipeline_quantize_q8_1;
+        default:
+            std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+    VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
+
+    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
+}
+
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4265,10 +4410,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+    // Check for mmq first
+    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+    if (mmp == nullptr) {
+        // Fall back to f16 dequant mul mat
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+        quantize_y = false;
+    }
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
@@ -4278,13 +4432,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
-    const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
+    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
+    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
-    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
     const int x_ne = ne01 * ne00;
     const int y_ne = padded_n * ne10;
     const int d_ne = ne11 * ne01;
@@ -4294,11 +4448,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
+    vk_pipeline to_q8_1 = nullptr;
 
     if (x_non_contig) {
         to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
@@ -4313,6 +4468,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
+    if (quantize_y) {
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    }
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -4326,7 +4485,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
             ctx->prealloc_size_x = x_sz_upd;
         }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
             ctx->prealloc_size_y = y_sz_upd;
         }
         if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@@ -4341,6 +4500,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         if (qy_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
         }
+        if (quantize_y) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
+        }
         if (split_k > 1) {
             ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
         }
@@ -4376,6 +4538,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
         GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
+    } else if (quantize_y) {
+        d_Y = ctx->prealloc_y;
+        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -4392,6 +4557,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
+    if (quantize_y) {
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+    }
 
     uint32_t stride_batch_x = ne00*ne01;
     uint32_t stride_batch_y = ne10*ne11;
@@ -4400,7 +4568,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
     }
 
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
@@ -6929,6 +7097,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
+
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7177,6 +7349,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
 
     ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
 
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
+
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7236,66 +7412,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
     free(x_chk);
 }
 
-static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
+// This does not work without ggml q8_1 quantization support
+//
+// typedef uint16_t ggml_half;
+// typedef uint32_t ggml_half2;
+//
+// #define QK8_1 32
+// typedef struct {
+//     union {
+//         struct {
+//             ggml_half d; // delta
+//             ggml_half s; // d * sum(qs[i])
+//         } GGML_COMMON_AGGR_S;
+//         ggml_half2 ds;
+//     } GGML_COMMON_AGGR_U;
+//     int8_t qs[QK8_1]; // quants
+// } block_q8_1;
+//
+// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
+//     VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
+//     GGML_ASSERT(quant == GGML_TYPE_Q8_1);
+//
+//     const size_t x_sz = sizeof(float) * ne;
+//     const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
+//     float * x = (float *) malloc(x_sz);
+//     block_q8_1 * qx     = (block_q8_1 *)malloc(qx_sz);
+//     block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
+//     vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+//     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+//
+//     for (size_t i = 0; i < ne; i++) {
+//         x[i] = rand() / (float)RAND_MAX;
+//     }
+//
+//     vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
+//
+//     ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
+//
+//     if (ctx->device->need_compiles) {
+//         ggml_vk_load_shaders(ctx->device);
+//     }
+//
+//     ggml_pipeline_allocate_descriptor_sets(ctx->device);
+//
+//     ggml_vk_buffer_write(x_buf, 0, x, x_sz);
+//
+//     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
+//     ggml_vk_ctx_begin(ctx->device, subctx);
+//     ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
+//     ggml_vk_ctx_end(subctx);
+//
+//     auto begin = std::chrono::high_resolution_clock::now();
+//
+//     ggml_vk_submit(subctx, ctx->fence);
+//     VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
+//     ctx->device->device.resetFences({ ctx->fence });
+//
+//     auto end = std::chrono::high_resolution_clock::now();
+//
+//     double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
+//     ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
+//
+//     ggml_vk_quantize_data(x, qx_res, ne, quant);
+//
+//     int first_err = -1;
+//
+//     for (size_t i = 0; i < ne / 32; i++) {
+//         double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
+//
+//         if (first_err < 0 && error > 0.1) {
+//             first_err = i;
+//         }
+//
+//         error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
+//
+//         if (first_err < 0 && error > 0.1) {
+//             first_err = i;
+//         }
+//
+//         for (size_t j = 0; j < 32; j++) {
+//             uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
+//
+//             if (first_err < 0 && error > 1) {
+//                 first_err = i;
+//             }
+//         }
+//     }
+//
+//     std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
+//
+//     if (first_err != -1) {
+//         std::cerr << "first_error = " << first_err << std::endl;
+//         std::cerr << "Actual result: " << std::endl << std::endl;
+//         std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
+//         for (size_t j = 0; j < 32; j++) {
+//             std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
+//         }
+//         std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
+//         std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
+//         for (size_t j = 0; j < 32; j++) {
+//             std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
+//         }
+//         std::cerr << std::endl;
+//     }
+//
+//     ggml_vk_destroy_buffer(x_buf);
+//     ggml_vk_destroy_buffer(qx_buf);
+//
+//     free(x);
+//     free(qx);
+//     free(qx_res);
+// }
+
+static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
     VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
     const size_t x_ne = m * k * batch;
     const size_t y_ne = k * n * batch;
     const size_t d_ne = m * n * batch;
 
+    vk_matmul_pipeline2 * pipelines;
+
+    if (mmq) {
+        pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
+    } else {
+        pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
+    }
+
+    const bool fp16acc = ctx->device->fp16;
+
     vk_pipeline p;
     std::string shname;
     if (shader_size == 0) {
-        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
+        p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
     } else if (shader_size == 1) {
-        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
+        p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
     } else if (shader_size == 2) {
-        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
+        p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
     } else {
         GGML_ASSERT(0);
     }
 
-    const size_t kpad = ggml_vk_align_size(k, p->align);
+    const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
 
-    if (k != kpad) {
+    if (mmq || k != kpad) {
         if (shader_size == 0) {
-            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
+            p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
             shname = std::string(ggml_type_name(quant)) + "_S";
         } else if (shader_size == 1) {
-            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
+            p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
             shname = std::string(ggml_type_name(quant)) + "_M";
         } else if (shader_size == 2) {
-            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
+            p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
             shname = std::string(ggml_type_name(quant)) + "_L";
         } else {
             GGML_ASSERT(0);
         }
     }
 
+    if (p == nullptr) {
+        std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
+        return;
+    }
+
     const size_t x_sz = sizeof(float) * x_ne;
     const size_t y_sz = sizeof(float) * y_ne;
     const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
+    const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
     const size_t d_sz = sizeof(float) * d_ne;
     float * x = (float *) malloc(x_sz);
     float * y = (float *) malloc(y_sz);
     void * qx = malloc(qx_sz);
     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     float * d = (float *) malloc(d_sz);
     float * d_chk = (float *) malloc(d_sz);
 
     for (size_t i = 0; i < x_ne; i++) {
         x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        // x[i] = i % k;
     }
 
     ggml_vk_quantize_data(x, qx, x_ne, quant);
 
     for (size_t i = 0; i < y_ne; i++) {
-        // y[i] = rand() / (float)RAND_MAX;
-        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        // y[i] = i % k;
     }
 
     ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
@@ -7310,6 +7618,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
             ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
         }
     }
+    if (mmq) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
+    }
+
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
 
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
@@ -7318,13 +7633,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
-    for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_matmul(
-            ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
-            m, n, k,
-            k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1, n
-        );
+    if (mmq) {
+        for (size_t i = 0; i < num_it; i++) {
+            ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
+            ggml_vk_matmul(
+                ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
+                m, n, k,
+                k, k, m, k*m, k*n, m*n,
+                split_k, batch, batch, batch, 1, 1, n
+            );
+        }
+    } else {
+        for (size_t i = 0; i < num_it; i++) {
+            ggml_vk_matmul(
+                ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
+                m, n, k,
+                k, k, m, k*m, k*n, m*n,
+                split_k, batch, batch, batch, 1, 1, n
+            );
+        }
     }
     ggml_vk_ctx_end(subctx);
 
@@ -7382,7 +7709,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
 
-    std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
+    std::cerr << "TEST dequant matmul " << shname;
+    if (mmq) {
+        std::cerr << " mmq";
+    }
+    std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
 
     if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -7392,6 +7723,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
+        std::cerr << "src0: " << std::endl << std::endl;
+        ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        std::cerr << "src1: " << std::endl << std::endl;
+        ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
+
         if (split_k > 1) {
             float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
             ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@@ -7414,6 +7751,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     ggml_vk_destroy_buffer(qx_buf);
     ggml_vk_destroy_buffer(y_buf);
+    ggml_vk_destroy_buffer(qy_buf);
     ggml_vk_destroy_buffer(d_buf);
 
     free(x);
@@ -7446,7 +7784,25 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
         128, 49, 49,
         4096, 49, 4096,
     };
-    const size_t num_it = 100;
+    const size_t num_it = 1;
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
+
+    abort();
 
     for (size_t i = 0; i < vals.size(); i += 3) {
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
@@ -9258,7 +9614,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     }
 
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
-        const float *params = (const float *)tensor->op_params;
+        const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@@ -9275,7 +9631,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_UPSCALE) {
         tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
     } else if (tensor->op == GGML_OP_SCALE) {
-        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
+        const float * params = (const float *)tensor->op_params;
+        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
@@ -9283,7 +9640,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_COS) {
         tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
-        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+        const float * params = (const float *)tensor->op_params;
+        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_PAD) {
         tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
     } else if (tensor->op == GGML_OP_REPEAT) {
@@ -9297,7 +9655,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_NORM) {
         tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
-        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
+        const float * float_params = (const float *)tensor->op_params;
+        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@@ -9310,14 +9669,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
-            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+            const float * params = (const float *)tensor->op_params;
+            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
     } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
         tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
-        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
+        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims      = ((int32_t *) tensor->op_params)[1];
         const int mode        = ((int32_t *) tensor->op_params)[2];
index 5a0054bac336c154949b494934e821921ce24f41..23ce8ceec332bc8f401a664233a844a8d4ab86ef 100644 (file)
@@ -212,7 +212,7 @@ void main() {
 #else
     ACC_TYPE sums[WMITER * TM * WNITER * TN];
     FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
+    FLOAT_TYPE cache_b[TN];
 
     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = ACC_TYPE(0.0f);
@@ -744,16 +744,14 @@ void main() {
             }
             [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                 [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
+                    cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
                 }
-            }
 
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                     [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                         [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                             const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
+                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
                         }
                     }
                 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
new file mode 100644 (file)
index 0000000..42f8135
--- /dev/null
@@ -0,0 +1,444 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#extension GL_EXT_integer_dot_product : require
+
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#endif
+
+#ifdef COOPMAT
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#endif
+
+#ifdef MUL_MAT_ID
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#endif
+
+#include "types.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+#ifdef MUL_MAT_ID
+layout (binding = 3) readonly buffer IDS {int data_ids[];};
+#endif
+
+layout (push_constant) uniform parameter
+{
+    uint M;
+    uint N;
+    uint K;
+    uint stride_a;
+    uint stride_b;
+    uint stride_d;
+
+    uint batch_stride_a;
+    uint batch_stride_b;
+    uint batch_stride_d;
+
+#ifdef MUL_MAT_ID
+    uint nei0;
+    uint nei1;
+    uint nbi1;
+    uint ne11;
+#else
+    uint k_split;
+    uint ne02;
+    uint ne12;
+    uint broadcast2;
+    uint broadcast3;
+#endif
+} p;
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 64;
+layout (constant_id = 1) const uint BM = 64;
+layout (constant_id = 2) const uint BN = 64;
+// layout (constant_id = 3) const uint BK = 32;
+layout (constant_id = 4) const uint WM = 32;
+layout (constant_id = 5) const uint WN = 32;
+layout (constant_id = 6) const uint WMITER = 2;
+layout (constant_id = 7) const uint TM = 4;
+layout (constant_id = 8) const uint TN = 2;
+layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
+layout (constant_id = 10) const uint WARP = 32;
+
+#define BK 32
+
+#ifdef COOPMAT
+#define SHMEM_STRIDE (BK / 4 + 4)
+#else
+#define SHMEM_STRIDE (BK / 4 + 1)
+#endif
+
+shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+
+#ifndef COOPMAT
+#if QUANT_AUXF == 1
+shared FLOAT_TYPE buf_a_dm[BM];
+#else
+shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
+#endif
+#endif
+
+shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
+#ifndef COOPMAT
+shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
+#endif
+
+#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_B 4
+
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[3072];
+#endif // MUL_MAT_ID
+
+#define NUM_WARPS (BLOCK_SIZE / WARP)
+
+#ifdef COOPMAT
+shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
+#endif
+
+#include "mul_mmq_funcs.comp"
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+#ifdef MUL_MAT_ID
+    const uint expert_idx = gl_GlobalInvocationID.z;
+#else
+    const uint batch_idx = gl_GlobalInvocationID.z;
+
+    const uint i13 = batch_idx / p.ne12;
+    const uint i12 = batch_idx % p.ne12;
+
+    const uint i03 = i13 / p.broadcast3;
+    const uint i02 = i12 / p.broadcast2;
+
+    const uint batch_idx_a = i03 * p.ne02 + i02;
+#endif
+
+    const uint blocks_m = (p.M + BM - 1) / BM;
+    const uint ir = gl_WorkGroupID.x % blocks_m;
+    const uint ik = gl_WorkGroupID.x / blocks_m;
+    const uint ic = gl_WorkGroupID.y;
+
+    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+    const uint WSUBM = WM / WMITER;
+    const uint WSUBN = WN / WNITER;
+
+#ifdef COOPMAT
+    const uint warp_i = gl_SubgroupID;
+
+    const uint tiw = gl_SubgroupInvocationID;
+
+    const uint cms_per_row = WM / TM;
+    const uint cms_per_col = WN / TN;
+
+    const uint storestride = WARP / TM;
+    const uint store_r = tiw % TM;
+    const uint store_c = tiw / TM;
+#else
+    const uint warp_i = gl_LocalInvocationID.x / WARP;
+
+    const uint tiw = gl_LocalInvocationID.x % WARP;
+
+    const uint tiwr = tiw % (WSUBM / TM);
+    const uint tiwc = tiw / (WSUBM / TM);
+#endif
+
+    const uint warp_r = warp_i % (BM / WM);
+    const uint warp_c = warp_i / (BM / WM);
+
+    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
+    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
+    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
+    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+
+    const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
+    const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
+
+#ifdef MUL_MAT_ID
+    uint _ne1 = 0;
+    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
+        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
+                row_ids[_ne1] = u16vec2(ii0, ii1);
+                _ne1++;
+            }
+        }
+    }
+
+    barrier();
+
+    // Workgroup has no work
+    if (ic * BN >= _ne1) return;
+#endif
+
+#ifdef MUL_MAT_ID
+    const uint start_k = 0;
+    const uint end_k = p.K;
+#else
+    const uint start_k = ik * p.k_split;
+    const uint end_k = min(p.K, (ik + 1) * p.k_split);
+#endif
+
+    uint pos_a_ib = (
+#ifdef MUL_MAT_ID
+        expert_idx * p.batch_stride_a +
+#else
+        batch_idx_a * p.batch_stride_a +
+#endif
+        ir * BM * p.stride_a + start_k) / BK;
+#ifdef MUL_MAT_ID
+    uint pos_b_ib = 0;
+#else
+    uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
+#endif
+
+#ifdef COOPMAT
+    coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
+    coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
+    coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
+
+    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
+
+    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
+
+    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
+        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
+    }
+#else
+    int32_t cache_a_qs[WMITER * TM * BK / 4];
+
+    int32_t cache_b_qs[TN * BK / 4];
+
+    ACC_TYPE sums[WMITER * TM * WNITER * TN];
+
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+        sums[i] = ACC_TYPE(0.0f);
+    }
+#endif
+
+#if QUANT_AUXF == 1
+    FLOAT_TYPE cache_a_dm[TM];
+#else
+    FLOAT_TYPE_VEC2 cache_a_dm[TM];
+#endif
+
+    FLOAT_TYPE_VEC2 cache_b_ds[TN];
+
+    for (uint block = start_k; block < end_k; block += BK) {
+        [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
+            const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
+            const uint iqs = loadr_a;
+            const uint buf_ib = loadc_a + l;
+
+            // Should ds be gated to a single thread?
+            if (iqs == 0) {
+#if QUANT_AUXF == 1
+                buf_a_dm[buf_ib] = get_d(ib);
+#else
+                buf_a_dm[buf_ib] = get_dm(ib);
+#endif
+            }
+#if QUANT_R == 1
+            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
+#else
+            const i32vec2 vals = repack(ib, iqs);
+            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs    ] = vals.x;
+            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
+#endif
+        }
+        [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+#ifdef MUL_MAT_ID
+            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+            const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
+            const uint ib = idx / 8;
+            const uint iqs = idx & 0x7;
+#else
+            const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
+            const uint iqs = loadr_b;
+#endif
+
+            const uint buf_ib = loadc_b + l;
+
+            // Should ds be gated to a single thread?
+            if (iqs == 0) {
+                buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
+            }
+            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
+        }
+
+        barrier();
+
+        pos_a_ib += 1;
+        pos_b_ib += 1;
+
+#ifdef COOPMAT
+        [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+            const uint ib_a = warp_r * WM + cm_row * TM;
+            // Load from shared into cache
+            coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
+
+            // TODO: only cache values that are actually needed
+            [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
+                cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
+            }
+
+            [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+                const uint ib_b = warp_c * WN + cm_col * TN;
+                coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
+
+                // TODO: only cache values that are actually needed
+                [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
+                    cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
+                }
+
+                cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
+                cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
+
+                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+                    coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
+                }
+
+                coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+                sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
+            }
+        }
+#else
+        // Load from shared into cache
+        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+                cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
+                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
+                    cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
+                }
+            }
+        }
+
+        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+                cache_b_ds[cc] = buf_b_ds[ib];
+                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
+                    cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+                }
+            }
+
+            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                        const uint cache_a_idx = wsir * TM + cr;
+                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+                        int32_t q_sum = 0;
+                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
+                            q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
+                                                     cache_b_qs[cc * (BK / 4) + idx_k]);
+                        }
+
+                        sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
+                    }
+                }
+            }
+        }
+#endif
+
+        barrier();
+    }
+
+    const uint dr = ir * BM + warp_r * WM;
+    const uint dc = ic * BN + warp_c * WN;
+
+#ifndef MUL_MAT_ID
+    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
+#endif
+
+#ifdef COOPMAT
+#ifdef MUL_MAT_ID
+    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+            [[unroll]] for (uint col = 0; col < BN; col += storestride) {
+                const uint row_i = dc + cm_col * TN + col + store_c;
+                if (row_i >= _ne1) break;
+
+                const u16vec2 row_idx = row_ids[row_i];
+
+                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+            }
+        }
+    }
+#else
+    const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float
+
+    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
+        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
+            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
+
+            if (is_aligned && is_in_bounds) {
+                // Full coopMat is within bounds and stride_d is aligned with 16B
+                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
+                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+            } else if (is_in_bounds) {
+                // Full coopMat is within bounds, but stride_d is not aligned
+                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+                }
+            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
+                // Partial coopMat is within bounds
+                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
+
+                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
+                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
+                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+                    }
+                }
+            }
+        }
+    }
+#endif // MUL_MAT_ID
+#else
+    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+
+            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
+            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
+            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+#ifdef MUL_MAT_ID
+                const uint row_i = dc_warp + cc;
+                if (row_i >= _ne1) break;
+
+                const u16vec2 row_idx = row_ids[row_i];
+#endif // MUL_MAT_ID
+                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+#ifdef MUL_MAT_ID
+                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+#else
+                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    }
+#endif // MUL_MAT_ID
+                }
+            }
+        }
+    }
+#endif // COOPMAT
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
new file mode 100644 (file)
index 0000000..c4c35e1
--- /dev/null
@@ -0,0 +1,99 @@
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#include "types.comp"
+
+// Each iqs value maps to a 32-bit integer
+
+#if defined(DATA_A_Q4_0)
+i32vec2 repack(uint ib, uint iqs) {
+    // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
+    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2    ],
+                                   data_a[ib].qs[iqs * 2 + 1]);
+    const uint32_t vui = pack32(quants);
+    return i32vec2( vui       & 0x0F0F0F0F,
+                   (vui >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
+    return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0 * dsb.y));
+}
+#endif
+
+#if defined(DATA_A_Q4_1)
+i32vec2 repack(uint ib, uint iqs) {
+    // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
+    const uint32_t vui = data_a_packed32[ib].qs[iqs];
+    return i32vec2( vui       & 0x0F0F0F0F,
+                   (vui >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
+    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
+}
+#endif
+
+#if defined(DATA_A_Q5_0)
+i32vec2 repack(uint ib, uint iqs) {
+    // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
+    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2    ],
+                                   data_a[ib].qs[iqs * 2 + 1]);
+    const uint32_t vui = pack32(quants);
+    const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
+                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+
+    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+    return i32vec2(v0, v1);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
+    return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0 * dsb.y));
+}
+#endif
+
+#if defined(DATA_A_Q5_1)
+i32vec2 repack(uint ib, uint iqs) {
+    // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
+    const uint32_t vui = data_a_packed32[ib].qs[iqs];
+    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
+                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+
+    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+    return i32vec2(v0, v1);
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
+    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+int32_t repack(uint ib, uint iqs) {
+    // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
+    return pack32(i16vec2(data_a[ib].qs[iqs * 2    ],
+                          data_a[ib].qs[iqs * 2 + 1]));
+}
+
+ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
+    return ACC_TYPE(float(q_sum) * da * dsb.x);
+}
+#endif
+
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+FLOAT_TYPE get_d(uint ib) {
+    return FLOAT_TYPE(data_a[ib].d);
+}
+#endif
+
+#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+}
+#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
new file mode 100644 (file)
index 0000000..e2e020f
--- /dev/null
@@ -0,0 +1,77 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_EXT_shader_16bit_storage : require
+
+layout (push_constant) uniform parameter
+{
+    uint ne;
+} p;
+
+#include "types.comp"
+
+layout(constant_id = 0) const uint GROUP_SIZE = 32;
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {vec4 data_a[];};
+layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
+
+shared float shmem[GROUP_SIZE];
+
+void quantize() {
+    const uint wgid = gl_WorkGroupID.x;
+    const uint tid = gl_LocalInvocationID.x;
+
+    // Each thread handles a vec4, so 8 threads handle a block
+    const uint blocks_per_group = GROUP_SIZE / 8;
+
+    const uint block_in_wg = tid / 8;
+
+    const uint ib = wgid * blocks_per_group + block_in_wg;
+    const uint iqs = tid % 8;
+
+    if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
+        return;
+    }
+
+    const uint a_idx = ib * 8 + iqs;
+
+    vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
+    const vec4 abs_vals = abs(vals);
+
+    // Find absolute max for each block
+    shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
+    barrier();
+    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
+        if (iqs < s) {
+            shmem[tid] = max(shmem[tid], shmem[tid + s]);
+        }
+        barrier();
+    }
+
+    const float amax = shmem[block_in_wg * 8];
+    const float d = amax / 127.0;
+    const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
+    vals = round(vals * d_inv);
+    data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
+    barrier();
+
+    // Calculate the sum for each block
+    shmem[tid] = vals.x + vals.y + vals.z + vals.w;
+    barrier();
+    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
+        if (iqs < s) {
+            shmem[tid] += shmem[tid + s];
+        }
+        barrier();
+    }
+    if (iqs == 0) {
+        const float sum = shmem[tid];
+
+        data_b[ib].ds = f16vec2(vec2(d, sum * d));
+    }
+}
+
+void main() {
+    quantize();
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp
new file mode 100644 (file)
index 0000000..470e307
--- /dev/null
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_EXT_integer_dot_product : require
+
+void main()
+{
+}
index 789776816b75a60e03d034f4a2ade45623dcf05a..f5b29bfb13a6690429a60ae26cf022c2e8d3b4fa 100644 (file)
@@ -1,4 +1,3 @@
-
 #if !defined(GGML_TYPES_COMP)
 #define GGML_TYPES_COMP
 
@@ -51,6 +50,7 @@ struct block_q4_0_packed16
 #if defined(DATA_A_Q4_0)
 #define QUANT_K QUANT_K_Q4_0
 #define QUANT_R QUANT_R_Q4_0
+#define QUANT_AUXF 1
 #define A_TYPE block_q4_0
 #define A_TYPE_PACKED16 block_q4_0_packed16
 #endif
@@ -72,11 +72,19 @@ struct block_q4_1_packed16
     uint16_t qs[16/2];
 };
 
+struct block_q4_1_packed32
+{
+    f16vec2 dm;
+    uint32_t qs[16/4];
+};
+
 #if defined(DATA_A_Q4_1)
 #define QUANT_K QUANT_K_Q4_1
 #define QUANT_R QUANT_R_Q4_1
+#define QUANT_AUXF 2
 #define A_TYPE block_q4_1
 #define A_TYPE_PACKED16 block_q4_1_packed16
+#define A_TYPE_PACKED32 block_q4_1_packed32
 #endif
 
 #define QUANT_K_Q5_0 32
@@ -99,6 +107,7 @@ struct block_q5_0_packed16
 #if defined(DATA_A_Q5_0)
 #define QUANT_K QUANT_K_Q5_0
 #define QUANT_R QUANT_R_Q5_0
+#define QUANT_AUXF 1
 #define A_TYPE block_q5_0
 #define A_TYPE_PACKED16 block_q5_0_packed16
 #endif
@@ -122,11 +131,20 @@ struct block_q5_1_packed16
     uint16_t qs[16/2];
 };
 
+struct block_q5_1_packed32
+{
+    f16vec2 dm;
+    uint qh;
+    uint32_t qs[16/4];
+};
+
 #if defined(DATA_A_Q5_1)
 #define QUANT_K QUANT_K_Q5_1
 #define QUANT_R QUANT_R_Q5_1
+#define QUANT_AUXF 2
 #define A_TYPE block_q5_1
 #define A_TYPE_PACKED16 block_q5_1_packed16
+#define A_TYPE_PACKED32 block_q5_1_packed32
 #endif
 
 #define QUANT_K_Q8_0 32
@@ -142,14 +160,40 @@ struct block_q8_0_packed16
     float16_t d;
     int16_t qs[32/2];
 };
+struct block_q8_0_packed32
+{
+    float16_t d;
+    int32_t qs[32/4];
+};
 
 #if defined(DATA_A_Q8_0)
 #define QUANT_K QUANT_K_Q8_0
 #define QUANT_R QUANT_R_Q8_0
+#define QUANT_AUXF 1
 #define A_TYPE block_q8_0
 #define A_TYPE_PACKED16 block_q8_0_packed16
+#define A_TYPE_PACKED32 block_q8_0_packed32
 #endif
 
+#define QUANT_K_Q8_1 32
+#define QUANT_R_Q8_1 1
+
+struct block_q8_1
+{
+    f16vec2 ds;
+    int8_t qs[32];
+};
+struct block_q8_1_packed16
+{
+    f16vec2 ds;
+    int16_t qs[16];
+};
+struct block_q8_1_packed32
+{
+    f16vec2 ds;
+    int32_t qs[8];
+};
+
 // K-quants
 #define QUANT_K_Q2_K 256
 
index 1edb8267f1ebef7bf671dda9f82a63f0cd10aa0e..2ac4caee70e17900f8f8ca17c6531ea0589079d0 100644 (file)
@@ -295,7 +295,10 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
 
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
+    std::map<std::string, std::string> base_dict = {
+        {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"},
+        {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
+    };
     std::string shader_name = "matmul";
 
     if (matmul_id) {
@@ -313,9 +316,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         base_dict["COOPMAT"] = "1";
     }
 
-    base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-
-    std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
+    const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
 
     // Shaders with f16 B_TYPE
     string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
@@ -339,14 +340,20 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32",         source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float"},            {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16",         source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+        }
+#endif
     }
 }
 
@@ -458,6 +465,7 @@ void process_shaders() {
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
+    string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
 
     string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});