]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
authorRuben Ortlam <redacted>
Wed, 29 Oct 2025 13:39:03 +0000 (14:39 +0100)
committerGitHub <redacted>
Wed, 29 Oct 2025 13:39:03 +0000 (14:39 +0100)
* vulkan: add mmq q2_k integer dot support

* Refactor mmq caching

* Reduce mmq register use

* Load 4 quant blocks into shared memory in one step

* Pack q2_k blocks into caches of 32

* Use 32-bit accumulators for integer dot matmul

* Add q4_k mmq

* Add q3_k mmq

* Add q5_k mmq

* Add q6_k mmq

* Add mxfp4 mmq, enable MMQ MUL_MAT_ID

* Fix mmv dm loads

18 files changed:
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 5caf37d4030ad1568061381b244d43e178b5693d..3d10aa07b0885c9a55d248f21d11df3df4748d9a 100644 (file)
@@ -486,6 +486,7 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
 
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_matmul_split_k_reduce;
     vk_pipeline pipeline_quantize_q8_1;
@@ -2448,8 +2449,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
                           l_warptile_id, m_warptile_id, s_warptile_id,
                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
                           l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
+                          l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
                           l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
-                          l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
+                          l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
+                          l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
+                          l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
                             l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
                             l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
@@ -2512,10 +2516,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmq = { 128,  64,  64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
         s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
+        // Integer MMQ has a smaller shared memory profile, but heavier register use
         l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
         m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
         s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, subgroup_size_8 };
 
+        // K-quants use even more registers, mitigate by setting WMITER to 1
+        l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int_k = { 128,  64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32,       32, 1, 2, 1, 1, subgroup_size_8 };
+
         l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
         m_warptile_id = { 128,  64,  64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
         s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
@@ -2524,10 +2534,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmqid = { 128,  64,  64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
         s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
 
+        l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid_int = { 128,  64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+
+        l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+        m_warptile_mmqid_int_k = { 128,  64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32,       32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+
         // chip specific tuning
         if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
             m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
-            m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+            m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
         }
 
         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@@ -2912,18 +2930,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
         if (device->mul_mat ## ID ## _s[TYPE]) \
             ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
 
-#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
         if (device->mul_mat ## ID ## _l[TYPE]) { \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ##  _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC        "_l", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC        "_l", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
         } \
         if (device->mul_mat ## ID ## _m[TYPE]) { \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ##  _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC        "_m", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC        "_m", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
         } \
         if (device->mul_mat ## ID ## _s[TYPE]) { \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ##  _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC        "_s", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC        "_s", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
         } \
 
         // Create 2 variants, {f16,f32} accumulator
@@ -2962,11 +2977,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (device->integer_dot_product) {
-            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
-            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
-            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
-            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
-            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+            CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
         }
 #endif
 
@@ -2996,6 +3019,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
             CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
             CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+            if (device->integer_dot_product) {
+                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+            }
+#endif
         } else {
             CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
             CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
@@ -3022,6 +3063,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
             CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
             CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+            if (device->integer_dot_product) {
+                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
+
+                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, 4, _id, 0);
+
+                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+            }
+#endif
         }
 #undef CREATE_MM2
 #undef CREATE_MMQ
@@ -3086,6 +3145,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
             CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
             CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+
+            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
         }
 #endif
 
@@ -3145,7 +3210,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
     // reusing CREATE_MM from the fp32 path
     if ((device->coopmat2 || device->coopmat_support)
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         && !device->coopmat_bf16_support
 #endif
         ) {
@@ -4928,7 +4993,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
 
     // MMQ
     if (src1_type == GGML_TYPE_Q8_1) {
-        vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
+        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
 
         if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
             return nullptr;
@@ -5075,6 +5140,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         }
     }
 
+    // MMQ
+    if (src1_type == GGML_TYPE_Q8_1) {
+        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
+
+        if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+            return nullptr;
+        }
+
+        return pipelines;
+    }
+
     GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
 
     switch (src0_type) {
@@ -6877,10 +6953,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+    // Check for mmq first
+    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+    if (mmp == nullptr) {
+        // Fall back to f16 dequant mul mat
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+        quantize_y = false;
+    }
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
@@ -6890,8 +6975,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
-    const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
+    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
+    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
 
@@ -6904,12 +6989,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
     const uint64_t ids_sz = nbi2;
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
+    vk_pipeline to_q8_1 = nullptr;
 
     if (x_non_contig) {
         to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
@@ -6924,9 +7010,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 
+    if (quantize_y) {
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
+    }
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
-        const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        uint64_t y_sz_upd = y_sz * ne12 * ne13;
+        if (quantize_y) {
+            y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+        }
         if (
                 (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
                 (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
@@ -6935,7 +7028,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
             ctx->prealloc_size_x = x_sz_upd;
         }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
             ctx->prealloc_size_y = y_sz_upd;
         }
 
@@ -6947,6 +7040,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         if (qy_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
+        if (quantize_y) {
+            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+        }
         return;
     }
 
@@ -6983,6 +7079,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
         GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
+    } else if (quantize_y) {
+        d_Y = ctx->prealloc_y;
+        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -7014,6 +7113,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
             ctx->prealloc_y_last_tensor_used = src1;
         }
     }
+    if (quantize_y) {
+        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            if (ctx->prealloc_y_need_sync) {
+                ggml_vk_sync_buffers(ctx, subctx);
+            }
+            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
+            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
+    }
 
     uint32_t stride_batch_x = ne00*ne01;
     uint32_t stride_batch_y = ne10*ne11;
@@ -7022,14 +7132,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
     }
 
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
+    uint32_t y_sz_total = y_sz * ne12 * ne13;
+    if (quantize_y) {
+        y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+    }
+
     // compute
     ggml_vk_matmul_id(
         ctx, subctx, pipeline,
-        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+        { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
         { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
         ne01, ne21, ne10, ne10, ne10, ne01,
         stride_batch_x, stride_batch_y, ne20*ne21,
index 0d98f5a9d6bf17704542b57f339903a7f6e41dfc..09676a623ba632c4a092c15f2ba5598b5f36685a 100644 (file)
@@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 #if defined(DATA_A_MXFP4)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
 }
 vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
     vec2 v0 = dequantize(ib, iqs, a_offset);
@@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
 
     const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
     const uint scales = data_a[a_offset + ib].scales[scalesi];
-    const vec2 d = vec2(data_a[a_offset + ib].d);
+    const vec2 dm = vec2(data_a[a_offset + ib].dm);
 
-    return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+    return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
 }
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(1, 0);
@@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint is = 2 * n + b;                 // 0..7
     const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
 
-    const vec2 loadd = vec2(data_a[a_offset + ib].d);
+    const vec2 loadd = vec2(data_a[a_offset + ib].dm);
 
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
 
     const uint8_t hm = uint8_t(1 << (iqs / 16));
 
-    const vec2 loadd = vec2(data_a[a_offset + ib].d);
+    const vec2 loadd = vec2(data_a[a_offset + ib].dm);
 
     const uint scidx0 = (is < 4) ? is : (is + 4);
     const uint scidx1 = (is < 4) ? is : (is - 4);
index 67baedf7c6147be1f9b89a04d64380412f7d826f..8ac6482dc944b83b1af7c4df9da1b8110698d741 100644 (file)
@@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2
 float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
 {
     decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
-    const f16vec2 d = bl.block.d;
+    const f16vec2 dm = bl.block.dm;
     const uint idx = coordInBlock[1];
 
     const uint scalesi = (idx & 0xF0) >> 4;             // 0..15
@@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2
     qs = unpack8(qs)[idx & 1];
 
     const uint scales = bl.block.scales[scalesi];
-    float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+    float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
     return ret;
 }
 
@@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
     uint32_t qs = bl.block.qs[iqs];
     qs >>= shift;
     qs &= 0xF;
-    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+    float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
     return ret;
 }
 #endif
index ffba5a77ddf53a9b84d3c5d62156def0c2bc72e8..3194ba291f311bd6b9de728f2ce121ff6b0d000b 100644 (file)
@@ -26,7 +26,7 @@ void main() {
     const float d = e8m0_to_fp32(data_a[ib].e);
 
     [[unroll]] for (uint l = 0; l < 8; ++l) {
-        data_b[b_idx + l +  0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
-        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]);
+        data_b[b_idx + l +  0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
+        data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]));
     }
 }
index 58dc2e5dfde9d3567d0111995d0f1fb59db6b61c..dc05a78348909723e71282b242eaa8de7de69741 100644 (file)
@@ -24,8 +24,8 @@ void main() {
         const uint ql_idx = 32 * ip + il;
         const uint8_t qs = data_a[i].qs[32 * ip + il];
 
-        FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
-        FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+        FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+        FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
         data_b[y_idx +  0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
         data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
         data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
index 8b7be557e95484281919aa23f7a4208f6d79c7b9..0f23dc0a349f62618c389f278cab8a855ce21bfe 100644 (file)
@@ -20,8 +20,8 @@ void main() {
         const uint is = 2 * il;
         const uint n = 4;
 
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
 
         const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
         const uint qs_idx = 32*il + n * ir;
index 6bc04670fc5931b74b7474d457901e204ea99775..970469a601cc654dcf848efa86638c675c530ada 100644 (file)
@@ -19,8 +19,8 @@ void main() {
         const uint ir = tid % 16;
         const uint is = 2 * il;
 
-        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+        const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+        const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
 
         const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
         const uint qs_idx = 32*il + 2 * ir;
index 03ed25d3bfe4e5ee50313472b109fe2cbfbc6cba..14093c0de5a4593ead242a14015d3400efa11725 100644 (file)
@@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
         const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
         const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
 
-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
 
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
@@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
                        fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ix][6 + 8*v_im],
                        fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
             }
-            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+            temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
         }
     }
 }
index 21d07d2e5096439a8cd34c2f59cf8c960f8c64c8..49d91ad59101ef54601cc0133f4e56b37247e343 100644 (file)
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
                 fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
                 fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
                 fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7)))))))))))))));
-            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
         }
     }
 }
index 9e46c89a11f509344b6dfdcf110f1e56cd5d7e95..0d61b4966ec4a424fff402c2cca852faa3419b4a 100644 (file)
@@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-        vec2 d = vec2(data_a[ib0 + i].d);
-        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
-        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+        const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
 
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
@@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
               fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
               fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
                   (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
-            temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+            temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
         }
     }
 }
index a20788c4b51e34505dbdcc0a09e9d1cb4d435c6a..d260969f07e882137def1b1e9a227849f197d88e 100644 (file)
@@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];
 
 #define NUM_WARPS (BLOCK_SIZE / WARP)
 
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[BN];
-uint _ne1;
-
-#ifdef MUL_MAT_ID_USE_SUBGROUPS
-shared uvec4 ballots_sh[NUM_WARPS];
-
-void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
-    _ne1 = 0;
-    uint num_elements = p.nei1 * p.nei0;
-    uint nei0shift = findLSB(p.nei0);
-
-    uint ids[16];
-    uint iter = 0;
-
-    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
-        // prefetch up to 16 elements
-        if (iter == 0) {
-            [[unroll]] for (uint k = 0; k < 16; ++k) {
-                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
-                bool in_range = i < num_elements;
-                uint ii1;
-                if (nei0_is_pow2) {
-                    ii1 = i >> nei0shift;
-                } else {
-                    ii1 = i / p.nei0;
-                }
-                uint ii0 = i - ii1 * p.nei0;
-                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
-            }
-        }
-        uint i = j + gl_LocalInvocationIndex;
-        bool in_range = i < num_elements;
-        uint ii1;
-        if (nei0_is_pow2) {
-            ii1 = i >> nei0shift;
-        } else {
-            ii1 = i / p.nei0;
-        }
-        uint ii0 = i - ii1 * p.nei0;
-        uint id = ids[iter++];
-        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-
-        ballots_sh[gl_SubgroupID] = ballot;
-        barrier();
-
-        uint subgroup_base = 0;
-        uint total = 0;
-        for (uint k = 0; k < gl_NumSubgroups; ++k) {
-            if (k == gl_SubgroupID) {
-                subgroup_base = total;
-            }
-            total += subgroupBallotBitCount(ballots_sh[k]);
-        }
-        barrier();
-
-        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
-        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
-            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
-        }
-        _ne1 += total;
-        iter &= 15;
-        if (_ne1 >= (ic + 1) * BN) {
-            break;
-        }
-    }
-    barrier();
-}
-#endif // MUL_MAT_ID_USE_SUBGROUPS
-#endif // MUL_MAT_ID
-
 #ifdef COOPMAT
 shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif
 
+#include "mul_mm_id_funcs.glsl"
 #include "mul_mm_funcs.glsl"
 
 void main() {
index 0ebfbd6462c8b76897ddc59b597cb6fdc12f22b1..ee5ded2e8d3eb17ab33df8bfd874dfe9f57fed66 100644 (file)
@@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint ib = idx / 128;                         // 2 values per idx
             const uint iqs = idx % 128;                        // 0..127
 
-            const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+            const uint qsi = (iqs / 64) * 16 + (iqs % 16);     // 0..15
             const uint scalesi = iqs / 8;                      // 0..15
             const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6
 
-            const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+            const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
             const uint scales = data_a[ib].scales[scalesi];
-            const vec2 d = vec2(data_a[ib].d);
+            const vec2 dm = vec2(data_a[ib].dm);
 
-            const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+            const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
 
             buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
 #elif defined(DATA_A_Q3_K)
@@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint is = 2 * n + b;                 // 0..7
             const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
 
-            const vec2 loadd = vec2(data_a[ib].d);
+            const vec2 loadd = vec2(data_a[ib].dm);
 
             const uint scidx0 = (is < 4) ? is : (is + 4);
             const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
 
             const uint8_t hm = uint8_t(1 << (iqs / 16));
 
-            const vec2 loadd = vec2(data_a[ib].d);
+            const vec2 loadd = vec2(data_a[ib].dm);
 
             const uint scidx0 = (is < 4) ? is : (is + 4);
             const uint scidx1 = (is < 4) ? is : (is - 4);
@@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             const uint ib = idx / 8;
             const uint iqs = (idx & 0x07) * 2;
 
-            const float d = e8m0_to_fp32(data_a[ib].e);
+            const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
             const uint vui = uint(data_a[ib].qs[iqs]);
             const uint vui2 = uint(data_a[ib].qs[iqs+1]);
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
new file mode 100644 (file)
index 0000000..1d0e84a
--- /dev/null
@@ -0,0 +1,70 @@
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[BN];
+uint _ne1;
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+shared uvec4 ballots_sh[NUM_WARPS];
+
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
+    _ne1 = 0;
+    uint num_elements = p.nei1 * p.nei0;
+    uint nei0shift = findLSB(p.nei0);
+
+    uint ids[16];
+    uint iter = 0;
+
+    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+        // prefetch up to 16 elements
+        if (iter == 0) {
+            [[unroll]] for (uint k = 0; k < 16; ++k) {
+                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+                bool in_range = i < num_elements;
+                uint ii1;
+                if (nei0_is_pow2) {
+                    ii1 = i >> nei0shift;
+                } else {
+                    ii1 = i / p.nei0;
+                }
+                uint ii0 = i - ii1 * p.nei0;
+                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+            }
+        }
+        uint i = j + gl_LocalInvocationIndex;
+        bool in_range = i < num_elements;
+        uint ii1;
+        if (nei0_is_pow2) {
+            ii1 = i >> nei0shift;
+        } else {
+            ii1 = i / p.nei0;
+        }
+        uint ii0 = i - ii1 * p.nei0;
+        uint id = ids[iter++];
+        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+        ballots_sh[gl_SubgroupID] = ballot;
+        barrier();
+
+        uint subgroup_base = 0;
+        uint total = 0;
+        for (uint k = 0; k < gl_NumSubgroups; ++k) {
+            if (k == gl_SubgroupID) {
+                subgroup_base = total;
+            }
+            total += subgroupBallotBitCount(ballots_sh[k]);
+        }
+        barrier();
+
+        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
+        }
+        _ne1 += total;
+        iter &= 15;
+        if (_ne1 >= (ic + 1) * BN) {
+            break;
+        }
+    }
+    barrier();
+}
+#endif // MUL_MAT_ID_USE_SUBGROUPS
+#endif // MUL_MAT_ID
index b5d761c0bab9e3bcc1e55b7c1db656e49f14125c..8b238ac4bc117f68f2c85c32cf7b1c254bf53dbc 100644 (file)
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #endif
 
-#ifdef COOPMAT
-#extension GL_KHR_cooperative_matrix : enable
-#extension GL_KHR_memory_scope_semantics : enable
+#if defined(MUL_MAT_ID_USE_SUBGROUPS)
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
 #endif
 
 #ifdef MUL_MAT_ID
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
 #if defined(A_TYPE_PACKED32)
 layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
 #endif
@@ -76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #define BK 32
 
-#ifdef COOPMAT
-#define SHMEM_STRIDE (BK / 4 + 4)
-#else
-#define SHMEM_STRIDE (BK / 4 + 1)
-#endif
+#define MMQ_SHMEM
 
-shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#include "mul_mmq_shmem_types.glsl"
 
-#ifndef COOPMAT
-#if QUANT_AUXF == 1
-shared FLOAT_TYPE buf_a_dm[BM];
-#else
-shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
-#endif
+#ifndef BK_STEP
+#define BK_STEP 4
 #endif
 
-shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
-#ifndef COOPMAT
-shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
-#endif
+// Shared memory cache
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b;
 
-#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
 #define LOAD_VEC_B 16
 
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
-#endif // MUL_MAT_ID
-
 #define NUM_WARPS (BLOCK_SIZE / WARP)
 
-#ifdef COOPMAT
-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
-#endif
-
+#include "mul_mm_id_funcs.glsl"
 #include "mul_mmq_funcs.glsl"
 
 void main() {
@@ -139,26 +128,12 @@ void main() {
     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
     const uint WSUBN = WN / WNITER;
-
-#ifdef COOPMAT
-    const uint warp_i = gl_SubgroupID;
-
-    const uint tiw = gl_SubgroupInvocationID;
-
-    const uint cms_per_row = WM / TM;
-    const uint cms_per_col = WN / TN;
-
-    const uint storestride = WARP / TM;
-    const uint store_r = tiw % TM;
-    const uint store_c = tiw / TM;
-#else
     const uint warp_i = gl_LocalInvocationID.x / WARP;
 
     const uint tiw = gl_LocalInvocationID.x % WARP;
 
     const uint tiwr = tiw % (WSUBM / TM);
     const uint tiwc = tiw / (WSUBM / TM);
-#endif
 
     const uint warp_r = warp_i % (BM / WM);
     const uint warp_c = warp_i / (BM / WM);
@@ -172,17 +147,27 @@ void main() {
     const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
 
 #ifdef MUL_MAT_ID
-    uint _ne1 = 0;
-    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
-        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+    if (bitCount(p.nei0) == 1) {
+        load_row_ids(expert_idx, true, ic);
+    } else {
+        load_row_ids(expert_idx, false, ic);
+    }
+#else
+    _ne1 = 0;
+    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
             if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
-                row_ids[_ne1] = u16vec2(ii0, ii1);
+                if (_ne1 >= ic * BN) {
+                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+                }
                 _ne1++;
             }
         }
     }
 
     barrier();
+#endif
 
     // Workgroup has no work
     if (ic * BN >= _ne1) return;
@@ -209,159 +194,70 @@ void main() {
     uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
 #endif
 
-#ifdef COOPMAT
-    coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
-    coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
-    coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
-
-    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
-
-    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
-
-    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
-        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
-    }
-#else
-    int32_t cache_a_qs[WMITER * TM * BK / 4];
-
-    int32_t cache_b_qs[TN * BK / 4];
-
     ACC_TYPE sums[WMITER * TM * WNITER * TN];
 
     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = ACC_TYPE(0.0f);
     }
-#endif
 
-#if QUANT_AUXF == 1
-    FLOAT_TYPE cache_a_dm[WMITER * TM];
-#else
-    FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
-#endif
-
-    FLOAT_TYPE_VEC2 cache_b_ds[TN];
-
-    for (uint block = start_k; block < end_k; block += BK) {
+    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
-            const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
-            const uint iqs = loadr_a;
             const uint buf_ib = loadc_a + l;
+            const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+            const uint iqs = loadr_a;
 
-            if (iqs == 0) {
-#if QUANT_AUXF == 1
-                buf_a_dm[buf_ib] = get_d(ib);
-#else
-                buf_a_dm[buf_ib] = get_dm(ib);
-#endif
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+                block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
             }
-#if QUANT_R == 1
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
-#else
-            const i32vec2 vals = repack(ib, iqs);
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs    ] = vals.x;
-            buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
-#endif
         }
         [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+            const uint buf_ib = loadc_b + l;
+
 #ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
-            const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x7;
+            const u16vec2 row_idx = row_ids[buf_ib];
+            const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
 #else
-            const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
-            const uint ib_outer = ib / 4;
-            const uint ib_inner = ib % 4;
-
-            const uint iqs = loadr_b;
+            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
 #endif
+            const uint iqs = loadr_b;
 
-            const uint buf_ib = loadc_b + l;
-
-            if (iqs == 0) {
-                buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
             }
-            const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4    ] = values.x;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
-            buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
         }
 
         barrier();
 
-        pos_a_ib += 1;
-        pos_b_ib += 1;
+        pos_a_ib += BK_STEP;
+        pos_b_ib += BK_STEP;
 
-#ifdef COOPMAT
-        [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-            const uint ib_a = warp_r * WM + cm_row * TM;
+        for (uint k_step = 0; k_step < BK_STEP; k_step++) {
             // Load from shared into cache
-            coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
-
-            // TODO: only cache values that are actually needed
-            [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
-                cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
-            }
-
-            [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-                const uint ib_b = warp_c * WN + cm_col * TN;
-                coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
-
-                // TODO: only cache values that are actually needed
-                [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
-                    cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
-                }
-
-                cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
-                cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
-                }
-
-                coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-                sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
-            }
-        }
-#else
-        // Load from shared into cache
-        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
-                cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
-                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                    cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
-                }
-            }
-        }
+            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                    const uint reg_ib = wsir * TM + cr;
+                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
 
-        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
-                cache_b_ds[cc] = buf_b_ds[ib];
-                [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                    cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
                 }
             }
 
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                        const uint cache_a_idx = wsir * TM + cr;
-                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                        int32_t q_sum = 0;
-                        [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
-                            q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
-                                                     cache_b_qs[cc * (BK / 4) + idx_k]);
-                        }
+                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+                    block_b_to_registers(ib);
 
-                        sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                            const uint cache_a_idx = wsir * TM + cr;
+                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+
+                            sums[sums_idx] += mmq_dot_product(cache_a_idx);
+                        }
                     }
                 }
             }
         }
-#endif
 
         barrier();
     }
@@ -373,54 +269,6 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
-#ifdef COOPMAT
-#ifdef MUL_MAT_ID
-    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-            [[unroll]] for (uint col = 0; col < BN; col += storestride) {
-                const uint row_i = dc + cm_col * TN + col + store_c;
-                if (row_i >= _ne1) break;
-
-                const u16vec2 row_idx = row_ids[row_i];
-
-                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-            }
-        }
-    }
-#else
-    const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float
-
-    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
-        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
-            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
-
-            if (is_aligned && is_in_bounds) {
-                // Full coopMat is within bounds and stride_d is aligned with 16B
-                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
-                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
-            } else if (is_in_bounds) {
-                // Full coopMat is within bounds, but stride_d is not aligned
-                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-                }
-            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
-                // Partial coopMat is within bounds
-                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
-                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
-                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
-                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
-                    }
-                }
-            }
-        }
-    }
-#endif // MUL_MAT_ID
-#else
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
 
@@ -431,19 +279,21 @@ void main() {
                 const uint row_i = dc_warp + cc;
                 if (row_i >= _ne1) break;
 
-                const u16vec2 row_idx = row_ids[row_i];
+                const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                    const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
 #ifdef MUL_MAT_ID
-                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + cr < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+                    }
 #else
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
                     }
 #endif // MUL_MAT_ID
                 }
             }
         }
     }
-#endif // COOPMAT
 }
index fe71eb131c807b837aca86341f336ff91cc96452..c0c03fedcc222681065770a554344552bcfe7aa0 100644 (file)
@@ -6,41 +6,89 @@
 
 // Each iqs value maps to a 32-bit integer
 
-#if defined(DATA_A_Q4_0)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+// 4-byte loads for Q4_1 blocks (20 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
-    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2    ],
-                                   data_a[ib].qs[iqs * 2 + 1]);
+#ifdef DATA_A_Q4_0
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
     const uint32_t vui = pack32(quants);
     return i32vec2( vui       & 0x0F0F0F0F,
                    (vui >> 4) & 0x0F0F0F0F);
+#else // DATA_A_Q4_1
+    const uint32_t vui = data_a_packed32[ib].qs[iqs];
+    return i32vec2( vui       & 0x0F0F0F0F,
+                   (vui >> 4) & 0x0F0F0F0F);
+#endif
 }
 
+#ifdef DATA_A_Q4_0
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
 }
+#else // DATA_A_Q4_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
 #endif
 
-#if defined(DATA_A_Q4_1)
-i32vec2 repack(uint ib, uint iqs) {
-    // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    return i32vec2( vui       & 0x0F0F0F0F,
-                   (vui >> 4) & 0x0F0F0F0F);
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q4_0
+    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+    }
+#else // DATA_A_Q4_1
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+    }
+#endif
 }
 
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
 }
-#endif
 
-#if defined(DATA_A_Q5_0)
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint32_t vui = cache_a[ib_a].qs[iqs];
+        const i32vec2 qs_a = i32vec2( vui       & 0x0F0F0F0F,
+                                     (vui >> 4) & 0x0F0F0F0F);
+
+        const int32_t qs_b0 = cache_b.qs[iqs];
+        const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+        q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
+        q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+// 4-byte loads for Q5_1 blocks (24 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
-    const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2    ],
-                                   data_a[ib].qs[iqs * 2 + 1]);
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
     const uint32_t vui = pack32(quants);
-    const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+#ifdef DATA_A_Q5_0
+    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+#else // DATA_A_Q5_1
+    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+#endif
     const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
                      | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
 
@@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) {
     return i32vec2(v0, v1);
 }
 
+#ifdef DATA_A_Q5_0
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
 }
+#else // DATA_A_Q5_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
 #endif
 
-#if defined(DATA_A_Q5_1)
-i32vec2 repack(uint ib, uint iqs) {
-    // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
-    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
-                     | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q5_0
+    buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
 
-    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
-                     | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+        buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
+    }
+#else // DATA_A_Q5_1
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
 
-    return i32vec2(v0, v1);
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+        buf_a[buf_ib].qh = data_a_packed32[ib].qh;
+    }
+#endif
 }
 
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].qh = buf_a[buf_ib].qh;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
 }
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const uint32_t vui = cache_a[ib_a].qs[iqs];
+        const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
+        const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
+                         | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+        const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                         | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+        const int32_t qs_b0 = cache_b.qs[iqs];
+        const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+        q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
+        q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
 #endif
 
 #if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
 int32_t repack(uint ib, uint iqs) {
-    // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
-    return pack32(i16vec2(data_a[ib].qs[iqs * 2    ],
-                          data_a[ib].qs[iqs * 2 + 1]));
+    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                          data_a_packed16[ib].qs[iqs * 2 + 1]));
 }
 
 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
     return ACC_TYPE(float(q_sum) * da * dsb.x);
 }
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
+                                           data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+        const int32_t qs_b = cache_b.qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, qs_b);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_MXFP4)
+// 1-byte loads for mxfp4 blocks (17 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
+                                          data_a[ib].qs[iqs * 4 + 1],
+                                          data_a[ib].qs[iqs * 4 + 2],
+                                          data_a[ib].qs[iqs * 4 + 3]));
+
+    return i32vec2( quants       & 0x0F0F0F0F,
+                   (quants >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(da * dsb.x * float(q_sum));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
+                                      data_a[ib].qs[iqs * 4 + 1],
+                                      data_a[ib].qs[iqs * 4 + 2],
+                                      data_a[ib].qs[iqs * 4 + 3]));
+
+    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);
+    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+
+    buf_a[buf_ib].qs[iqs    ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
+    buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
+
+    if (iqs == 0) {
+        buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].d = buf_a[buf_ib].d;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
+#if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
+int32_t repack(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+    return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
+}
+
+uint8_t get_scale(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    return data_a[ib_k].scales[iqs_k / 4];
+}
+
+ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+    // Repack 4x4 quants into one int
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+
+    if (iqs == 0) {
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+        const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q3_K)
+// 2-byte loads for Q3_K blocks (110 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint hm_idx = iqs * QUANT_R_MMQ;
+    const uint iqs_k = (ib % 8) * 8 + hm_idx;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    const uint hm_shift = iqs_k / 8;
+
+    // Repack 2x4 quants into one int
+    // Add the 3rd bit instead of subtracting it to allow packing the quants
+    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2        ] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2    ] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1    ] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2    ] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3    ] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
+    buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
+                           (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
+
+    if (iqs == 0) {
+        const uint is = iqs_k / 4;
+        const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8      ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
+                                            (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
+
+        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    float result = 0.0;
+    int32_t q_sum = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        // Subtract 4 from the quants to correct the 3rd bit offset
+        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+    q_sum = 0;
+
+    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+        const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+    return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 16) / 8) * 4;
+
+    // Repack 2x4 quants into one int
+#if defined(DATA_A_Q4_K)
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
+
+    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
+#else // defined(DATA_A_Q5_K)
+    const uint qh_idx = iqs * QUANT_R_MMQ;
+    const uint qh_shift = iqs_k / 8;
+
+    buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
+                                   (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
+#endif
+
+
+    if (iqs == 0) {
+        // Scale index
+        const uint is = iqs_k / 8;
+        u8vec2 scale_dm;
+        if (is < 4) {
+            scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
+        } else {
+            scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
+                              (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));
+        }
+
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    int32_t q_sum = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+#if defined(DATA_A_Q4_K)
+        const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
+#else // defined(DATA_A_Q5_K)
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+#endif
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+
+    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#ifdef MMQ_SHMEM
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_outer = ib / 4;
+    const uint ib_inner = ib % 4;
+
+    if (iqs == 0) {
+        buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+    }
+
+    const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+    buf_b[buf_ib].qs[iqs * 4    ] = values.x;
+    buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+    buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+    buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+    cache_b.ds = buf_b[ib].ds;
+    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
+}
+#endif
+
+#if defined(DATA_A_Q6_K)
+// 2-byte loads for Q6_K blocks (210 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
+    const uint ql_shift = ((iqs_k % 32) / 16) * 4;
+
+    const uint qh_idx = (iqs_k / 32) * 8 + iqs;
+    const uint qh_shift = ((iqs_k % 32) / 8) * 2;
+
+    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
+
+    if (iqs == 0) {
+        const uint is = iqs_k / 4;
+        const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
+
+        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+    }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+    cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+    }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+    float result = 0.0;
+    int32_t q_sum = 0;
+
+    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+    q_sum = 0;
+
+    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+        const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+        q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+    }
+    result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+    return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
 #endif
 
 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
@@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) {
     return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
 }
 #endif
+
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+    const uint ib_k = ib / 8;
+    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
new file mode 100644 (file)
index 0000000..72fec44
--- /dev/null
@@ -0,0 +1,78 @@
+#if defined(DATA_A_Q4_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q4_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    uint32_t qh;
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q5_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[16/4];
+    uint32_t qh;
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q8_0)
+#define QUANT_R_MMQ 1
+// AMD likes 4, Intel likes 1 and Nvidia likes 2
+#define BK_STEP 1
+struct block_a_cache {
+    int32_t qs[32/4];
+    FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_MXFP4)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE d;
+};
+#elif defined(DATA_A_Q2_K)
+#define QUANT_R_MMQ 4
+struct block_a_cache {
+    uint32_t qs[2];
+    u8vec2 scales;
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q3_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[4];
+    FLOAT_TYPE_VEC2 d_scales;
+};
+#elif defined(DATA_A_Q4_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+    uint32_t qs[4];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q6_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 d_scales;
+};
+#endif
+
+struct block_b_cache
+{
+    int32_t qs[8];
+    FLOAT_TYPE_VEC2 ds;
+};
index 2fa54ce51fc83f156fe0895f0d1b7bc0717366b5..02578c77c4f310dd00843a7689f4566349406486 100644 (file)
@@ -66,6 +66,7 @@ struct block_q4_0_packed16
 #define QUANT_AUXF 1
 #define A_TYPE block_q4_0
 #define A_TYPE_PACKED16 block_q4_0_packed16
+#define DATA_A_QUANT_LEGACY
 #endif
 
 #define QUANT_K_Q4_1 32
@@ -98,6 +99,7 @@ struct block_q4_1_packed32
 #define A_TYPE block_q4_1
 #define A_TYPE_PACKED16 block_q4_1_packed16
 #define A_TYPE_PACKED32 block_q4_1_packed32
+#define DATA_A_QUANT_LEGACY
 #endif
 
 #define QUANT_K_Q5_0 32
@@ -123,6 +125,7 @@ struct block_q5_0_packed16
 #define QUANT_AUXF 1
 #define A_TYPE block_q5_0
 #define A_TYPE_PACKED16 block_q5_0_packed16
+#define DATA_A_QUANT_LEGACY
 #endif
 
 #define QUANT_K_Q5_1 32
@@ -158,6 +161,7 @@ struct block_q5_1_packed32
 #define A_TYPE block_q5_1
 #define A_TYPE_PACKED16 block_q5_1_packed16
 #define A_TYPE_PACKED32 block_q5_1_packed32
+#define DATA_A_QUANT_LEGACY
 #endif
 
 #define QUANT_K_Q8_0 32
@@ -186,6 +190,7 @@ struct block_q8_0_packed32
 #define A_TYPE block_q8_0
 #define A_TYPE_PACKED16 block_q8_0_packed16
 #define A_TYPE_PACKED32 block_q8_0_packed32
+#define DATA_A_QUANT_LEGACY
 #endif
 
 #define QUANT_K_Q8_1 32
@@ -226,21 +231,21 @@ struct block_q2_K
 {
     uint8_t scales[QUANT_K_Q2_K/16];
     uint8_t qs[QUANT_K_Q2_K/4];
-    f16vec2 d;
+    f16vec2 dm;
 };
 
 struct block_q2_K_packed16
 {
     uint16_t scales[QUANT_K_Q2_K/16/2];
     uint16_t qs[QUANT_K_Q2_K/4/2];
-    f16vec2 d;
+    f16vec2 dm;
 };
 
 struct block_q2_K_packed32
 {
     uint32_t scales[QUANT_K_Q2_K/16/4];
     uint32_t qs[QUANT_K_Q2_K/4/4];
-    f16vec2 d;
+    f16vec2 dm;
 };
 
 #if defined(DATA_A_Q2_K)
@@ -249,6 +254,8 @@ struct block_q2_K_packed32
 #define A_TYPE block_q2_K
 #define A_TYPE_PACKED16 block_q2_K_packed16
 #define A_TYPE_PACKED32 block_q2_K_packed32
+#define SCALES_PER_32 2
+#define DATA_A_QUANT_K
 #endif
 
 #define QUANT_K_Q3_K 256
@@ -274,27 +281,28 @@ struct block_q3_K_packed16
 #define QUANT_R 1
 #define A_TYPE block_q3_K
 #define A_TYPE_PACKED16 block_q3_K_packed16
+#define DATA_A_QUANT_K
 #endif
 
 #define QUANT_K_Q4_K 256
 
 struct block_q4_K
 {
-    f16vec2 d;
+    f16vec2 dm;
     uint8_t scales[3*QUANT_K_Q4_K/64];
     uint8_t qs[QUANT_K_Q4_K/2];
 };
 
 struct block_q4_K_packed16
 {
-    f16vec2 d;
+    f16vec2 dm;
     uint16_t scales[3*QUANT_K_Q4_K/64/2];
     uint16_t qs[QUANT_K_Q4_K/2/2];
 };
 
 struct block_q4_K_packed32
 {
-    f16vec2 d;
+    f16vec2 dm;
     uint32_t scales[3*QUANT_K_Q4_K/64/4];
     uint32_t qs[QUANT_K_Q4_K/2/4];
 };
@@ -310,13 +318,14 @@ struct block_q4_K_packed128
 #define A_TYPE block_q4_K
 #define A_TYPE_PACKED16 block_q4_K_packed16
 #define A_TYPE_PACKED32 block_q4_K_packed32
+#define DATA_A_QUANT_K
 #endif
 
 #define QUANT_K_Q5_K 256
 
 struct block_q5_K
 {
-    f16vec2 d;
+    f16vec2 dm;
     uint8_t scales[12];
     uint8_t qh[QUANT_K_Q5_K/8];
     uint8_t qs[QUANT_K_Q5_K/2];
@@ -324,12 +333,20 @@ struct block_q5_K
 
 struct block_q5_K_packed16
 {
-    f16vec2 d;
+    f16vec2 dm;
     uint16_t scales[12/2];
     uint16_t qh[QUANT_K_Q5_K/8/2];
     uint16_t qs[QUANT_K_Q5_K/2/2];
 };
 
+struct block_q5_K_packed32
+{
+    f16vec2 dm;
+    uint32_t scales[12/4];
+    uint32_t qh[QUANT_K_Q5_K/8/4];
+    uint32_t qs[QUANT_K_Q5_K/2/4];
+};
+
 struct block_q5_K_packed128
 {
     uvec4 q5k[11];
@@ -340,6 +357,8 @@ struct block_q5_K_packed128
 #define QUANT_R 1
 #define A_TYPE block_q5_K
 #define A_TYPE_PACKED16 block_q5_K_packed16
+#define A_TYPE_PACKED32 block_q5_K_packed32
+#define DATA_A_QUANT_K
 #endif
 
 #define QUANT_K_Q6_K 256
@@ -356,7 +375,7 @@ struct block_q6_K_packed16
 {
     uint16_t ql[QUANT_K_Q6_K/2/2];
     uint16_t qh[QUANT_K_Q6_K/4/2];
-    int8_t scales[QUANT_K_Q6_K/16];
+    int16_t scales[QUANT_K_Q6_K/16/2];
     float16_t d;
 };
 
@@ -365,6 +384,7 @@ struct block_q6_K_packed16
 #define QUANT_R 1
 #define A_TYPE block_q6_K
 #define A_TYPE_PACKED16 block_q6_K_packed16
+#define DATA_A_QUANT_K
 #endif
 
 // IQuants
@@ -1363,18 +1383,11 @@ struct block_mxfp4
     uint8_t qs[QUANT_K_MXFP4/2];
 };
 
-//struct block_mxfp4_packed16
-//{
-//    uint8_t e;
-//    uint16_t qs[QUANT_K_MXFP4/2/2];
-//};
-
 #if defined(DATA_A_MXFP4)
 #define QUANT_K QUANT_K_MXFP4
 #define QUANT_R QUANT_R_MXFP4
 #define QUANT_AUXF 1
 #define A_TYPE block_mxfp4
-//#define A_TYPE_PACKED16 block_mxfp4_packed16
 #endif
 
 #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
@@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize)
 #endif
 
 #if defined(DATA_A_MXFP4)
-const FLOAT_TYPE kvalues_mxfp4_const[16] = {
-    FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
-    FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+const int8_t kvalues_mxfp4_const[16] = {
+    int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
+    int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
 };
 
-shared FLOAT_TYPE kvalues_mxfp4[16];
+shared int8_t kvalues_mxfp4[16];
 
 #define NEEDS_INIT_IQ_SHMEM
 void init_iq_shmem(uvec3 wgsize)
index 0f25ba3453093306a1965da6b7ab139601b728e8..03fa016398c189f3b80af6cd680a2e8580515f48 100644 (file)
@@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
         }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-        if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
+        // Integer dot mmq performs better with f32 accumulators
+        if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
             string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
@@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 }
 
 void process_shaders() {
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
+    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
 
     // matmul
     for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {