vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
l_warptile_id, m_warptile_id, s_warptile_id,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
+ l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
- l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
+ l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
+ l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
+ // Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+ // K-quants use even more registers, mitigate by setting WMITER to 1
+ l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+ m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
+ s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
+
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
+ l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+ m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+ s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+
+ l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+ m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+ s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
- m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+ m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
-#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
// Create 2 variants, {f16,f32} accumulator
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
}
#endif
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
+ }
+#endif
} else {
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
+ }
+#endif
}
#undef CREATE_MM2
#undef CREATE_MMQ
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
}
#endif
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
&& !device->coopmat_bf16_support
#endif
) {
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
- vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
}
}
+ // MMQ
+ if (src1_type == GGML_TYPE_Q8_1) {
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
+
+ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ return nullptr;
+ }
+
+ return pipelines;
+ }
+
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
switch (src0_type) {
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+ // Check for mmq first
+ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+ if (mmp == nullptr) {
+ // Fall back to f16 dequant mul mat
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+ quantize_y = false;
+ }
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
- const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
- const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t ids_sz = nbi2;
const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
+ vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
+ if (quantize_y) {
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
+ }
+
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
+ }
if (
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
}
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
}
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
+ if (quantize_y) {
+ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+ }
return;
}
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
+ } else if (quantize_y) {
+ d_Y = ctx->prealloc_y;
+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
ctx->prealloc_y_last_tensor_used = src1;
}
}
+ if (quantize_y) {
+ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+ ctx->prealloc_y_last_tensor_used != src1) {
+ if (ctx->prealloc_y_need_sync) {
+ ggml_vk_sync_buffers(ctx, subctx);
+ }
+ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
+ ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+ ctx->prealloc_y_last_tensor_used = src1;
+ }
+ }
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
+ uint32_t y_sz_total = y_sz * ne12 * ne13;
+ if (quantize_y) {
+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
+ }
+
// compute
ggml_vk_matmul_id(
ctx, subctx, pipeline,
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
ne01, ne21, ne10, ne10, ne10, ne01,
stride_batch_x, stride_batch_y, ne20*ne21,
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
- return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5;
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
vec2 v0 = dequantize(ib, iqs, a_offset);
const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
const uint scales = data_a[a_offset + ib].scales[scalesi];
- const vec2 d = vec2(data_a[a_offset + ib].d);
+ const vec2 dm = vec2(data_a[a_offset + ib].dm);
- return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+ return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
return vec2(1, 0);
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const vec2 loadd = vec2(data_a[a_offset + ib].d);
+ const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint8_t hm = uint8_t(1 << (iqs / 16));
- const vec2 loadd = vec2(data_a[a_offset + ib].d);
+ const vec2 loadd = vec2(data_a[a_offset + ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
- const f16vec2 d = bl.block.d;
+ const f16vec2 dm = bl.block.dm;
const uint idx = coordInBlock[1];
const uint scalesi = (idx & 0xF0) >> 4; // 0..15
qs = unpack8(qs)[idx & 1];
const uint scales = bl.block.scales[scalesi];
- float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4);
+ float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
return ret;
}
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
- float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+ float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
return ret;
}
#endif
const float d = e8m0_to_fp32(data_a[ib].e);
[[unroll]] for (uint l = 0; l < 8; ++l) {
- data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
- data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+ data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]));
+ data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]));
}
}
const uint ql_idx = 32 * ip + il;
const uint8_t qs = data_a[i].qs[32 * ip + il];
- FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
- FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
+ FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x);
+ FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y);
data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
const uint is = 2 * il;
const uint n = 4;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + n * ir;
const uint qs_idx = 32*il + n * ir;
const uint ir = tid % 16;
const uint is = 2 * il;
- const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y);
+ const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x);
+ const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y);
const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir;
const uint qs_idx = 32*il + 2 * ir;
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
- temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
+ temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n]));
}
}
}
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
- vec2 d = vec2(data_a[ib0 + i].d);
- const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
- const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
+ const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
- temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
+ temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n]));
}
}
}
#define NUM_WARPS (BLOCK_SIZE / WARP)
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[BN];
-uint _ne1;
-
-#ifdef MUL_MAT_ID_USE_SUBGROUPS
-shared uvec4 ballots_sh[NUM_WARPS];
-
-void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
- _ne1 = 0;
- uint num_elements = p.nei1 * p.nei0;
- uint nei0shift = findLSB(p.nei0);
-
- uint ids[16];
- uint iter = 0;
-
- for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
- // prefetch up to 16 elements
- if (iter == 0) {
- [[unroll]] for (uint k = 0; k < 16; ++k) {
- uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
- }
- }
- uint i = j + gl_LocalInvocationIndex;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- uint id = ids[iter++];
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
-
- ballots_sh[gl_SubgroupID] = ballot;
- barrier();
-
- uint subgroup_base = 0;
- uint total = 0;
- for (uint k = 0; k < gl_NumSubgroups; ++k) {
- if (k == gl_SubgroupID) {
- subgroup_base = total;
- }
- total += subgroupBallotBitCount(ballots_sh[k]);
- }
- barrier();
-
- uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
- row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
- }
- _ne1 += total;
- iter &= 15;
- if (_ne1 >= (ic + 1) * BN) {
- break;
- }
- }
- barrier();
-}
-#endif // MUL_MAT_ID_USE_SUBGROUPS
-#endif // MUL_MAT_ID
-
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
+#include "mul_mm_id_funcs.glsl"
#include "mul_mm_funcs.glsl"
void main() {
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127
- const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+ const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
- const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+ const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const uint scales = data_a[ib].scales[scalesi];
- const vec2 d = vec2(data_a[ib].d);
+ const vec2 dm = vec2(data_a[ib].dm);
- const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+ const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
- const vec2 loadd = vec2(data_a[ib].d);
+ const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint8_t hm = uint8_t(1 << (iqs / 16));
- const vec2 loadd = vec2(data_a[ib].d);
+ const vec2 loadd = vec2(data_a[ib].dm);
const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;
- const float d = e8m0_to_fp32(data_a[ib].e);
+ const float d = e8m0_to_fp32(data_a[ib].e) * 0.5;
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
--- /dev/null
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[BN];
+uint _ne1;
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+shared uvec4 ballots_sh[NUM_WARPS];
+
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+ uint nei0shift = findLSB(p.nei0);
+
+ uint ids[16];
+ uint iter = 0;
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+ ballots_sh[gl_SubgroupID] = ballot;
+ barrier();
+
+ uint subgroup_base = 0;
+ uint total = 0;
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier();
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
+ }
+ _ne1 += total;
+ iter &= 15;
+ if (_ne1 >= (ic + 1) * BN) {
+ break;
+ }
+ }
+ barrier();
+}
+#endif // MUL_MAT_ID_USE_SUBGROUPS
+#endif // MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
-#ifdef COOPMAT
-#extension GL_KHR_cooperative_matrix : enable
-#extension GL_KHR_memory_scope_semantics : enable
+#if defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
#endif
#ifdef MUL_MAT_ID
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#define BK 32
-#ifdef COOPMAT
-#define SHMEM_STRIDE (BK / 4 + 4)
-#else
-#define SHMEM_STRIDE (BK / 4 + 1)
-#endif
+#define MMQ_SHMEM
-shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
+#include "mul_mmq_shmem_types.glsl"
-#ifndef COOPMAT
-#if QUANT_AUXF == 1
-shared FLOAT_TYPE buf_a_dm[BM];
-#else
-shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
-#endif
+#ifndef BK_STEP
+#define BK_STEP 4
#endif
-shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
-#ifndef COOPMAT
-shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
-#endif
+// Shared memory cache
+shared block_a_cache buf_a[BM * BK_STEP];
+shared block_b_cache buf_b[BN * BK_STEP];
+// Register cache
+block_a_cache cache_a[WMITER * TM];
+block_b_cache cache_b;
-#define LOAD_VEC_A (4 * QUANT_R)
+#define LOAD_VEC_A (4 * QUANT_R_MMQ)
#define LOAD_VEC_B 16
-#ifdef MUL_MAT_ID
-shared u16vec2 row_ids[4096];
-#endif // MUL_MAT_ID
-
#define NUM_WARPS (BLOCK_SIZE / WARP)
-#ifdef COOPMAT
-shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
-#endif
-
+#include "mul_mm_id_funcs.glsl"
#include "mul_mmq_funcs.glsl"
void main() {
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
-
-#ifdef COOPMAT
- const uint warp_i = gl_SubgroupID;
-
- const uint tiw = gl_SubgroupInvocationID;
-
- const uint cms_per_row = WM / TM;
- const uint cms_per_col = WN / TN;
-
- const uint storestride = WARP / TM;
- const uint store_r = tiw % TM;
- const uint store_c = tiw / TM;
-#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
-#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
- uint _ne1 = 0;
- for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
- for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+ if (bitCount(p.nei0) == 1) {
+ load_row_ids(expert_idx, true, ic);
+ } else {
+ load_row_ids(expert_idx, false, ic);
+ }
+#else
+ _ne1 = 0;
+ for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
+ for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
- row_ids[_ne1] = u16vec2(ii0, ii1);
+ if (_ne1 >= ic * BN) {
+ row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
+ }
_ne1++;
}
}
}
barrier();
+#endif
// Workgroup has no work
if (ic * BN >= _ne1) return;
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
-#ifdef COOPMAT
- coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
- coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
- coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
-
- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
-
- coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
-
- [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
- sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
- }
-#else
- int32_t cache_a_qs[WMITER * TM * BK / 4];
-
- int32_t cache_b_qs[TN * BK / 4];
-
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
-#endif
-#if QUANT_AUXF == 1
- FLOAT_TYPE cache_a_dm[WMITER * TM];
-#else
- FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
-#endif
-
- FLOAT_TYPE_VEC2 cache_b_ds[TN];
-
- for (uint block = start_k; block < end_k; block += BK) {
+ for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
- const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
- const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
+ const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
+ const uint iqs = loadr_a;
- if (iqs == 0) {
-#if QUANT_AUXF == 1
- buf_a_dm[buf_ib] = get_d(ib);
-#else
- buf_a_dm[buf_ib] = get_dm(ib);
-#endif
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
}
-#if QUANT_R == 1
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
-#else
- const i32vec2 vals = repack(ib, iqs);
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
- buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
-#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+ const uint buf_ib = loadc_b + l;
+
#ifdef MUL_MAT_ID
- const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
- const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
- const uint ib = idx / 8;
- const uint iqs = idx & 0x7;
+ const u16vec2 row_idx = row_ids[buf_ib];
+ const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK;
#else
- const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
- const uint ib_outer = ib / 4;
- const uint ib_inner = ib % 4;
-
- const uint iqs = loadr_b;
+ const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
#endif
+ const uint iqs = loadr_b;
- const uint buf_ib = loadc_b + l;
-
- if (iqs == 0) {
- buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+ [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) {
+ block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
}
- const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
- buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}
barrier();
- pos_a_ib += 1;
- pos_b_ib += 1;
+ pos_a_ib += BK_STEP;
+ pos_b_ib += BK_STEP;
-#ifdef COOPMAT
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- const uint ib_a = warp_r * WM + cm_row * TM;
+ for (uint k_step = 0; k_step < BK_STEP; k_step++) {
// Load from shared into cache
- coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
-
- // TODO: only cache values that are actually needed
- [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
- cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
- }
-
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- const uint ib_b = warp_c * WN + cm_col * TN;
- coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
-
- // TODO: only cache values that are actually needed
- [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
- cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx];
- }
-
- cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
- cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col]));
- }
-
- coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
- sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
- }
- }
-#else
- // Load from shared into cache
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
- cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
- }
- }
- }
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint reg_ib = wsir * TM + cr;
+ const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
- const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
- cache_b_ds[cc] = buf_b_ds[ib];
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
+ block_a_to_registers(reg_ib, k_step * BM + buf_ib);
}
}
- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
- [[unroll]] for (uint cr = 0; cr < TM; cr++) {
- const uint cache_a_idx = wsir * TM + cr;
- const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
- int32_t q_sum = 0;
- [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
- q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
- cache_b_qs[cc * (BK / 4) + idx_k]);
- }
+ const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+ block_b_to_registers(ib);
- sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+ [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint cache_a_idx = wsir * TM + cr;
+ const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
+
+ sums[sums_idx] += mmq_dot_product(cache_a_idx);
+ }
}
}
}
}
-#endif
barrier();
}
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
-#ifdef COOPMAT
-#ifdef MUL_MAT_ID
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < BN; col += storestride) {
- const uint row_i = dc + cm_col * TN + col + store_c;
- if (row_i >= _ne1) break;
-
- const u16vec2 row_idx = row_ids[row_i];
-
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- }
- }
-#else
- const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
-
- [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
- [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
- const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
-
- if (is_aligned && is_in_bounds) {
- // Full coopMat is within bounds and stride_d is aligned with 16B
- coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
- coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
- } else if (is_in_bounds) {
- // Full coopMat is within bounds, but stride_d is not aligned
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
- // Partial coopMat is within bounds
- coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
-
- [[unroll]] for (uint col = 0; col < TN; col += storestride) {
- if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
- data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
- }
- }
- }
- }
- }
-#endif // MUL_MAT_ID
-#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
- const u16vec2 row_idx = row_ids[row_i];
+ const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
+ const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr;
#ifdef MUL_MAT_ID
- data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ if (dr_warp + cr < p.M) {
+ data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
+ }
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
- data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+ data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x);
}
#endif // MUL_MAT_ID
}
}
}
}
-#endif // COOPMAT
}
// Each iqs value maps to a 32-bit integer
-#if defined(DATA_A_Q4_0)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]);
+#ifdef DATA_A_Q4_0
+ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
+#else // DATA_A_Q4_1
+ const uint32_t vui = data_a_packed32[ib].qs[iqs];
+ return i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+#endif
}
+#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
+#else // DATA_A_Q4_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
#endif
-#if defined(DATA_A_Q4_1)
-i32vec2 repack(uint ib, uint iqs) {
- // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
- const uint32_t vui = data_a_packed32[ib].qs[iqs];
- return i32vec2( vui & 0x0F0F0F0F,
- (vui >> 4) & 0x0F0F0F0F);
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q4_0
+ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ }
+#else // DATA_A_Q4_1
+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+ }
+#endif
}
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
}
-#endif
-#if defined(DATA_A_Q5_0)
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const uint32_t vui = cache_a[ib_a].qs[iqs];
+ const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F,
+ (vui >> 4) & 0x0F0F0F0F);
+
+ const int32_t qs_b0 = cache_b.qs[iqs];
+ const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+ q_sum += dotPacked4x8EXT(qs_a.x, qs_b0);
+ q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+
+#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
- const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]);
+ const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
- const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
+#ifdef DATA_A_Q5_0
+ const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+#else // DATA_A_Q5_1
+ const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
return i32vec2(v0, v1);
}
+#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
+#else // DATA_A_Q5_1
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
#endif
-#if defined(DATA_A_Q5_1)
-i32vec2 repack(uint ib, uint iqs) {
- // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
- const uint32_t vui = data_a_packed32[ib].qs[iqs];
- const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
- const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
- | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+#ifdef DATA_A_Q5_0
+ buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
- const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
- | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1]));
+ }
+#else // DATA_A_Q5_1
+ buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs];
- return i32vec2(v0, v1);
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+ buf_a[buf_ib].qh = data_a_packed32[ib].qh;
+ }
+#endif
}
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
- return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+ cache_a[reg_ib].qh = buf_a[buf_ib].qh;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const uint32_t vui = cache_a[ib_a].qs[iqs];
+ const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs));
+ const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F)
+ | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+ const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+ | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+ const int32_t qs_b0 = cache_b.qs[iqs];
+ const int32_t qs_b1 = cache_b.qs[iqs + 4];
+
+ q_sum += dotPacked4x8EXT(qs_a0, qs_b0);
+ q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
- // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
- return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
- data_a[ib].qs[iqs * 2 + 1]));
+ return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
+ data_a_packed16[ib].qs[iqs * 2 + 1]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+ const int32_t qs_b = cache_b.qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, qs_b);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_MXFP4)
+// 1-byte loads for mxfp4 blocks (17 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+ const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+ data_a[ib].qs[iqs * 4 + 1],
+ data_a[ib].qs[iqs * 4 + 2],
+ data_a[ib].qs[iqs * 4 + 3]));
+
+ return i32vec2( quants & 0x0F0F0F0F,
+ (quants >> 4) & 0x0F0F0F0F);
+}
+
+ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(da * dsb.x * float(q_sum));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+ data_a[ib].qs[iqs * 4 + 1],
+ data_a[ib].qs[iqs * 4 + 2],
+ data_a[ib].qs[iqs * 4 + 3]));
+
+ const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
+ const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+
+ buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w]));
+ buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]));
+
+ if (iqs == 0) {
+ buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d = buf_a[buf_ib].d;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
+// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
+#if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
+int32_t repack(uint ib, uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+ return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
+}
+
+uint8_t get_scale(uint ib, uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ return data_a[ib_k].scales[iqs_k / 4];
+}
+
+ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+ // Repack 4x4 quants into one int
+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303;
+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+ const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+ const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+
+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+
+ if (iqs == 0) {
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+ buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+ cache_a[reg_ib].scales = buf_a[buf_ib].scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t sum_d = 0;
+ int32_t sum_m = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ const uint8_t scale = cache_a[ib_a].scales[iqs / 4];
+ const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303);
+
+ sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
+ sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q3_K)
+// 2-byte loads for Q3_K blocks (110 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint hm_idx = iqs * QUANT_R_MMQ;
+ const uint iqs_k = (ib % 8) * 8 + hm_idx;
+
+ const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+ const uint hm_shift = iqs_k / 8;
+
+ // Repack 2x4 quants into one int
+ // Add the 3rd bit instead of subtracting it to allow packing the quants
+ const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
+ const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
+ buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) |
+ (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4);
+
+ if (iqs == 0) {
+ const uint is = iqs_k / 4;
+ const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) |
+ (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4)));
+
+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ float result = 0.0;
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ // Subtract 4 from the quants to correct the 3rd bit offset
+ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+ q_sum = 0;
+
+ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+ const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4));
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+ return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
+ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+ return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
+}
+
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
+
+ const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
+ const uint qs_shift = ((iqs_k % 16) / 8) * 4;
+
+ // Repack 2x4 quants into one int
+#if defined(DATA_A_Q4_K)
+ const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
+ const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
+
+ buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4);
+#else // defined(DATA_A_Q5_K)
+ const uint qh_idx = iqs * QUANT_R_MMQ;
+ const uint qh_shift = iqs_k / 8;
+
+ buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) |
+ (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
+#endif
+
+
+ if (iqs == 0) {
+ // Scale index
+ const uint is = iqs_k / 8;
+ u8vec2 scale_dm;
+ if (is < 4) {
+ scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
+ } else {
+ scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
+ (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
+ }
+
+ buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].dm = buf_a[buf_ib].dm;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+#if defined(DATA_A_Q4_K)
+ const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F);
+#else // defined(DATA_A_Q5_K)
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+#endif
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+
+ return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+}
+#endif // MMQ_SHMEM
+#endif
+
+#ifdef MMQ_SHMEM
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_outer = ib / 4;
+ const uint ib_inner = ib % 4;
+
+ if (iqs == 0) {
+ buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+ }
+
+ const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+ buf_b[buf_ib].qs[iqs * 4 ] = values.x;
+ buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+ buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+ buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+ cache_b.ds = buf_b[ib].ds;
+ [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+ cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+ }
+}
+#endif
+
+#if defined(DATA_A_Q6_K)
+// 2-byte loads for Q6_K blocks (210 bytes)
+#ifdef MMQ_SHMEM
+void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+ const uint ib_k = ib / 8;
+ const uint iqs_k = (ib % 8) * 8 + iqs;
+
+ const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
+ const uint ql_shift = ((iqs_k % 32) / 16) * 4;
+
+ const uint qh_idx = (iqs_k / 32) * 8 + iqs;
+ const uint qh_shift = ((iqs_k % 32) / 8) * 2;
+
+ const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+ const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
+ unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+ buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y));
+
+ if (iqs == 0) {
+ const uint is = iqs_k / 4;
+ const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]);
+
+ buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+ }
+}
+
+void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
+ cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales;
+
+ [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+ cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
+ }
+}
+
+ACC_TYPE mmq_dot_product(const uint ib_a) {
+ float result = 0.0;
+ int32_t q_sum = 0;
+
+ [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[0]) * float(q_sum);
+ q_sum = 0;
+
+ [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
+ const int32_t qs_a = cache_a[ib_a].qs[iqs];
+
+ q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
+ }
+ result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
+
+ return ACC_TYPE(cache_b.ds.x * result);
+}
+#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
+
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+ const uint ib_k = ib / 8;
+ return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
--- /dev/null
+#if defined(DATA_A_Q4_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q4_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_0)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ uint32_t qh;
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_Q5_1)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[16/4];
+ uint32_t qh;
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q8_0)
+#define QUANT_R_MMQ 1
+// AMD likes 4, Intel likes 1 and Nvidia likes 2
+#define BK_STEP 1
+struct block_a_cache {
+ int32_t qs[32/4];
+ FLOAT_TYPE dm;
+};
+#elif defined(DATA_A_MXFP4)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE d;
+};
+#elif defined(DATA_A_Q2_K)
+#define QUANT_R_MMQ 4
+struct block_a_cache {
+ uint32_t qs[2];
+ u8vec2 scales;
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q3_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[4];
+ FLOAT_TYPE_VEC2 d_scales;
+};
+#elif defined(DATA_A_Q4_K)
+#define QUANT_R_MMQ 2
+struct block_a_cache {
+ uint32_t qs[4];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q5_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 dm;
+};
+#elif defined(DATA_A_Q6_K)
+#define QUANT_R_MMQ 1
+struct block_a_cache {
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 d_scales;
+};
+#endif
+
+struct block_b_cache
+{
+ int32_t qs[8];
+ FLOAT_TYPE_VEC2 ds;
+};
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q4_1 32
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_0 32
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q5_1 32
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_0 32
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
+#define DATA_A_QUANT_LEGACY
#endif
#define QUANT_K_Q8_1 32
{
uint8_t scales[QUANT_K_Q2_K/16];
uint8_t qs[QUANT_K_Q2_K/4];
- f16vec2 d;
+ f16vec2 dm;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K_Q2_K/16/2];
uint16_t qs[QUANT_K_Q2_K/4/2];
- f16vec2 d;
+ f16vec2 dm;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K_Q2_K/16/4];
uint32_t qs[QUANT_K_Q2_K/4/4];
- f16vec2 d;
+ f16vec2 dm;
};
#if defined(DATA_A_Q2_K)
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
+#define SCALES_PER_32 2
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q3_K 256
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q4_K 256
struct block_q4_K
{
- f16vec2 d;
+ f16vec2 dm;
uint8_t scales[3*QUANT_K_Q4_K/64];
uint8_t qs[QUANT_K_Q4_K/2];
};
struct block_q4_K_packed16
{
- f16vec2 d;
+ f16vec2 dm;
uint16_t scales[3*QUANT_K_Q4_K/64/2];
uint16_t qs[QUANT_K_Q4_K/2/2];
};
struct block_q4_K_packed32
{
- f16vec2 d;
+ f16vec2 dm;
uint32_t scales[3*QUANT_K_Q4_K/64/4];
uint32_t qs[QUANT_K_Q4_K/2/4];
};
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q5_K 256
struct block_q5_K
{
- f16vec2 d;
+ f16vec2 dm;
uint8_t scales[12];
uint8_t qh[QUANT_K_Q5_K/8];
uint8_t qs[QUANT_K_Q5_K/2];
struct block_q5_K_packed16
{
- f16vec2 d;
+ f16vec2 dm;
uint16_t scales[12/2];
uint16_t qh[QUANT_K_Q5_K/8/2];
uint16_t qs[QUANT_K_Q5_K/2/2];
};
+struct block_q5_K_packed32
+{
+ f16vec2 dm;
+ uint32_t scales[12/4];
+ uint32_t qh[QUANT_K_Q5_K/8/4];
+ uint32_t qs[QUANT_K_Q5_K/2/4];
+};
+
struct block_q5_K_packed128
{
uvec4 q5k[11];
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
+#define A_TYPE_PACKED32 block_q5_K_packed32
+#define DATA_A_QUANT_K
#endif
#define QUANT_K_Q6_K 256
{
uint16_t ql[QUANT_K_Q6_K/2/2];
uint16_t qh[QUANT_K_Q6_K/4/2];
- int8_t scales[QUANT_K_Q6_K/16];
+ int16_t scales[QUANT_K_Q6_K/16/2];
float16_t d;
};
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
+#define DATA_A_QUANT_K
#endif
// IQuants
uint8_t qs[QUANT_K_MXFP4/2];
};
-//struct block_mxfp4_packed16
-//{
-// uint8_t e;
-// uint16_t qs[QUANT_K_MXFP4/2/2];
-//};
-
#if defined(DATA_A_MXFP4)
#define QUANT_K QUANT_K_MXFP4
#define QUANT_R QUANT_R_MXFP4
#define QUANT_AUXF 1
#define A_TYPE block_mxfp4
-//#define A_TYPE_PACKED16 block_mxfp4_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
#endif
#if defined(DATA_A_MXFP4)
-const FLOAT_TYPE kvalues_mxfp4_const[16] = {
- FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
- FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+const int8_t kvalues_mxfp4_const[16] = {
+ int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12),
+ int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12),
};
-shared FLOAT_TYPE kvalues_mxfp4[16];
+shared int8_t kvalues_mxfp4[16];
#define NEEDS_INIT_IQ_SHMEM
void init_iq_shmem(uvec3 wgsize)
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
- if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
+ // Integer dot mmq performs better with f32 accumulators
+ if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
}
void process_shaders() {
- std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
+ std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
// matmul
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {