#endif
static void ggml_vk_destroy_buffer(vk_buffer& buf);
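+// Upper bound on the number of B columns (ne11) that a single mul_mat_vec dispatch can
+// handle; one shader variant is compiled per column count (spec constant NUM_COLS below).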
+static constexpr uint32_t mul_mat_vec_max_cols = 8;
+
struct vk_device_struct {
std::mutex mutex;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
- vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
- vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+ vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
rm_stdq = 2;
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
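+    // Build one pipeline variant per column count; the third specialization constant (NUM_COLS) is i+1.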
+ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+ }
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
}
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
+ GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
switch (a_type) {
case GGML_TYPE_F32:
return nullptr;
}
- return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
+ return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
}
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
const uint64_t ne12 = src1->ne[2];
const uint64_t ne13 = src1->ne[3];
- GGML_ASSERT(ne11 == 1);
-
const uint64_t ne20 = dst->ne[0];
const uint64_t ne21 = dst->ne[1];
const uint64_t ne22 = dst->ne[2];
const uint64_t r2 = ne12 / ne02;
const uint64_t r3 = ne13 / ne03;
+    // batch_n indicates that we need to compute several vector results at once, which requires
+    // ne12 and ne13 to be 1. It overloads the batch strides to hold the row strides instead.
+ GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
+ bool batch_n = ne11 > 1;
+
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
- vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr);
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
}
- uint32_t stride_batch_x = ne00*ne01;
- uint32_t stride_batch_y = ne10*ne11;
+ // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
+ uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
+ uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
+ uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
// compute
const vk_mat_vec_push_constants pc = {
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
- stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
+ stride_batch_x, stride_batch_y, stride_batch_d,
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
};
ggml_vk_sync_buffers(subctx);
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
- } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+    // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size
+    // (up to mul_mat_vec_max_cols) when ne12 and ne13 are one.
+ } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
} else {
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#define K_PER_ITER 8
#else
uint a_offset, b_offset, d_offset, y_offset;
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
-void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
+void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
- const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
- const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
- const uint iybs = col - col%QUANT_K; // y block start index
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
+ const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
+ const uint iybs = col - col%QUANT_K; // y block start index
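+        // col, iqs and iybs are identical for every column j; only the B loads below
+        // differ, offset by j*p.batch_stride_b.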
#if K_PER_ITER == 8
#if QUANT_R == 2
- const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
- const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
- const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
- const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
+ const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
+ const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
+ const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
+ const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
- const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
- const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
+ const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+ const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
- // Check if the second of the pair of elements is OOB, and don't fetch B or
- // accumulate it. We still fetch a pair of elements for A, which is fine for
- // quantized formats since they'll be within the same block. We should
- // probably skip fetching the second element for F16/F32, but as of now we
- // still do.
- const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
-
- FLOAT_TYPE b0 = 0, b1 = 0;
- b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
- if (!OOB) {
- b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
- }
+ // Check if the second of the pair of elements is OOB, and don't fetch B or
+ // accumulate it. We still fetch a pair of elements for A, which is fine for
+ // quantized formats since they'll be within the same block. We should
+ // probably skip fetching the second element for F16/F32, but as of now we
+ // still do.
+ const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
+
+ FLOAT_TYPE b0 = 0, b1 = 0;
+ b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
+ if (!OOB) {
+ b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
+ }
#endif
- uint ibi = first_row*p.ncols;
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- const uint ib = (ibi + col)/QUANT_K; // block index
- ibi += p.ncols;
+ uint ibi = first_row*p.ncols;
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ const uint ib = (ibi + col)/QUANT_K; // block index
+ ibi += p.ncols;
#if K_PER_ITER == 8
- vec4 v = dequantize4(ib, iqs, a_offset);
- vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
+ vec4 v = dequantize4(ib, iqs, a_offset);
+ vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
- const vec2 dm = get_dm(ib, a_offset);
- if (dm.y != 0) { // quant has min component
- v = v * dm.x + dm.y;
- v2 = v2 * dm.x + dm.y;
- }
+ const vec2 dm = get_dm(ib, a_offset);
+ if (dm.y != 0) { // quant has min component
+ v = v * dm.x + dm.y;
+ v2 = v2 * dm.x + dm.y;
+ }
- // matrix multiplication
- FLOAT_TYPE rowtmp = dot(bv0, v);
- rowtmp += dot(bv1, v2);
+ // matrix multiplication
+ FLOAT_TYPE rowtmp = dot(bv0, v);
+ rowtmp += dot(bv1, v2);
- if (dm.y == 0)
- rowtmp *= dm.x;
+ if (dm.y == 0)
+ rowtmp *= dm.x;
- temp[n] += rowtmp;
+ temp[j][n] += rowtmp;
#else
- const vec2 v = dequantize(ib, iqs, a_offset);
+ const vec2 v = dequantize(ib, iqs, a_offset);
- // matrix multiplication
- temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
- if (!OOB) {
- temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
- }
+ // matrix multiplication
+ temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
+ if (!OOB) {
+ temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
+ }
#endif
+ }
}
}
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
- FLOAT_TYPE temp[NUM_ROWS];
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- for (uint i = 0; i < NUM_ROWS; ++i) {
- temp[i] = FLOAT_TYPE(0);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
i++;
}
- // sum up partial sums and write back result
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] = temp[n];
- }
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] += tmpsh[n][tid + s];
- }
- }
- barrier();
- }
- if (tid == 0) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
- }
- }
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
batch_idx * p.batch_stride_d;
#endif
}
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
+layout (constant_id = 2) const uint NUM_COLS = 1;
+
+shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
+
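+// Sum the per-thread partials across the workgroup and have thread 0 write one result
+// per (column, row) pair; column j of D starts j*p.batch_stride_d elements in.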
+void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
+ // sum up partial sums and write back result
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ tmpsh[j][n][tid] = temp[j][n];
+ }
+ }
+ barrier();
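+    // Tree reduction: halve the number of active threads each step until tmpsh[j][n][0]
+    // holds the complete sum.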
+ [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
+ if (tid < s) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
+ }
+ }
+ }
+ barrier();
+ }
+ if (tid == 0) {
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+ data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
+ }
+ }
+ }
+}
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0;
- FLOAT_TYPE temp[NUM_ROWS];
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
- temp[i] = FLOAT_TYPE(0);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
- B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
- B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
- B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
- B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
- B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
- B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
- B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
- B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
uvec2 qs0 = uvec2(unpack8(qs0_u16));
uvec2 qs16 = uvec2(unpack8(qs16_u16));
- FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
- FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
- [[unroll]] for (int l = 0; l < 2; ++l) {
- sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
- sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
- fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
- fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
- fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
- fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
- fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
- fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
- fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
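+            // 16 y values for column j, fetched as eight vec2 spaced 16 elements apart.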
+ B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
+ B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
+ B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
+ B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
+ B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
+ B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
+ B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
+ B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+
+ FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+ FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+ [[unroll]] for (int l = 0; l < 2; ++l) {
+ sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+ sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
+ fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
+ fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
+ fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
+ fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
+ fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
+ fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
+ fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+ }
+ temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
- temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
}
}
- // sum up partial sums and write back result
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] = temp[n];
- }
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] += tmpsh[n][tid + s];
- }
- }
- barrier();
- }
- if (tid == 0) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
- }
- }
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
- FLOAT_TYPE temp[NUM_ROWS];
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
- temp[i] = FLOAT_TYPE(0);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
const uint s_shift = 4 * v_im;
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
- B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
- B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
- B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
- B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
- B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
- B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
- B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
- B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
u8vec2 s8 = unpack8(s8_16);
u8vec2 s10 = unpack8(s10_16);
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- [[unroll]] for (int l = 0; l < 2; ++l) {
- sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
- fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
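+            // The unpacked scales (s0..s10) and hmask terms depend only on the row, so
+            // only the B loads vary with the column j.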
+ B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
+ B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
+ B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
+ B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
+ B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
+ B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
+ B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
+ B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ [[unroll]] for (int l = 0; l < 2; ++l) {
+ sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
+ fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
+ }
+ temp[j][n] = fma(d, sum, temp[j][n]);
}
- temp[n] = fma(d, sum, temp[n]);
}
}
- // sum up partial sums and write back result
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] = temp[n];
- }
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] += tmpsh[n][tid + s];
- }
- }
- barrier();
- }
- if (tid == 0) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
- }
- }
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
- FLOAT_TYPE temp[NUM_ROWS];
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
- temp[i] = FLOAT_TYPE(0);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
- B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4];
- B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
- B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4];
- B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
const uint32_t q4_14 = qs64_hi4.z;
const uint32_t q4_15 = qs64_hi4.w;
- const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
- const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
- const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
- const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
- const FLOAT_TYPE smin =
- fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
- fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
- fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
- fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
- temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
- }
- }
-
- // sum up partial sums and write back result
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] = temp[n];
- }
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] += tmpsh[n][tid + s];
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
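+            // Four vec4 loads per column: 16 y values split between the y1 and y2 halves of the block.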
+ B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
+ B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
+ B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
+ B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
+
+ const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
+ const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
+ const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
+ const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
+ const FLOAT_TYPE smin =
+ fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
+ fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
+ fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
+ fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
- barrier();
- }
- if (tid == 0) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
- }
}
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint q_offset = 32*v_im + l0;
const uint y_offset = 64*v_im + l0;
- FLOAT_TYPE temp[NUM_ROWS];
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
- temp[i] = FLOAT_TYPE(0);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
- B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
- B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
- B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
- B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
- B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
- B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
- B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
- B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
const uint32_t q4_14 = qs64_80_hi4.z;
const uint32_t q4_15 = qs64_80_hi4.w;
- const FLOAT_TYPE sx =
- fma(FLOAT_TYPE(by10.x), q4_0,
- fma(FLOAT_TYPE(by10.y), q4_1,
- fma(FLOAT_TYPE(by116.x), q4_2,
- FLOAT_TYPE(by116.y) * q4_3)));
- const FLOAT_TYPE sy =
- fma(FLOAT_TYPE(by132.x), q4_4,
- fma(FLOAT_TYPE(by132.y), q4_5,
- fma(FLOAT_TYPE(by148.x), q4_6,
- FLOAT_TYPE(by148.y) * q4_7)));
- const FLOAT_TYPE sz =
- fma(FLOAT_TYPE(by20.x), q4_8,
- fma(FLOAT_TYPE(by20.y), q4_9,
- fma(FLOAT_TYPE(by216.x), q4_10,
- FLOAT_TYPE(by216.y) * q4_11)));
- const FLOAT_TYPE sw =
- fma(FLOAT_TYPE(by232.x), q4_12,
- fma(FLOAT_TYPE(by232.y), q4_13,
- fma(FLOAT_TYPE(by248.x), q4_14,
- FLOAT_TYPE(by248.y) * q4_15)));
- const FLOAT_TYPE smin =
- fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
- fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
- fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
- (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
- temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
- }
- }
-
- // sum up partial sums and write back result
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] = temp[n];
- }
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] += tmpsh[n][tid + s];
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
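+            // Eight vec2 loads per column, four from each of the y1 and y2 halves of the block.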
+ B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
+ B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
+ B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
+ B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
+ B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
+ B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
+ B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
+ B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
+
+ const FLOAT_TYPE sx =
+ fma(FLOAT_TYPE(by10.x), q4_0,
+ fma(FLOAT_TYPE(by10.y), q4_1,
+ fma(FLOAT_TYPE(by116.x), q4_2,
+ FLOAT_TYPE(by116.y) * q4_3)));
+ const FLOAT_TYPE sy =
+ fma(FLOAT_TYPE(by132.x), q4_4,
+ fma(FLOAT_TYPE(by132.y), q4_5,
+ fma(FLOAT_TYPE(by148.x), q4_6,
+ FLOAT_TYPE(by148.y) * q4_7)));
+ const FLOAT_TYPE sz =
+ fma(FLOAT_TYPE(by20.x), q4_8,
+ fma(FLOAT_TYPE(by20.y), q4_9,
+ fma(FLOAT_TYPE(by216.x), q4_10,
+ FLOAT_TYPE(by216.y) * q4_11)));
+ const FLOAT_TYPE sw =
+ fma(FLOAT_TYPE(by232.x), q4_12,
+ fma(FLOAT_TYPE(by232.y), q4_13,
+ fma(FLOAT_TYPE(by248.x), q4_14,
+ FLOAT_TYPE(by248.y) * q4_15)));
+ const FLOAT_TYPE smin =
+ fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
+ fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
+ fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
+ (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
+ temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
- barrier();
- }
- if (tid == 0) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
- }
}
+
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-layout (constant_id = 0) const uint BLOCK_SIZE = 32;
-layout (constant_id = 1) const uint NUM_ROWS = 1;
-
-shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
-
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint s_offset = 8*v_im + is;
const uint y_offset = 128*v_im + l0;
- FLOAT_TYPE temp[NUM_ROWS];
+ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
- [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
- temp[i] = FLOAT_TYPE(0);
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+ [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+ temp[j][i] = FLOAT_TYPE(0);
+ }
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;
- B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
- B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
- B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
- B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
-
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
uvec4 q2 = uvec4(unpack8(q2_u32));
uvec4 q3 = uvec4(unpack8(q3_u32));
- FLOAT_TYPE sum = FLOAT_TYPE(0.0);
- [[unroll]] for (int l = 0; l < 4; ++l) {
- sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
- fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
- fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
- fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
+ [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
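+            // Four vec4 loads for column j, spanning y_idx+0, +32, +64 and +96.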
+ B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
+ B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
+ B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
+ B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
+
+ FLOAT_TYPE sum = FLOAT_TYPE(0.0);
+ [[unroll]] for (int l = 0; l < 4; ++l) {
+ sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
+ fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
+ fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
+ fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
+ }
+ temp[j][n] += sum * d;
}
- temp[n] += sum * d;
}
}
- // sum up partial sums and write back result
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] = temp[n];
- }
- barrier();
- [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
- if (tid < s) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- tmpsh[n][tid] += tmpsh[n][tid + s];
- }
- }
- barrier();
- }
- if (tid == 0) {
- [[unroll]] for (uint n = 0; n < num_rows; ++n) {
- data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
- }
- }
+ reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
- for (int bs : {1, 512}) {
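+    // bs 2..8 exercise the multi-column mul_mat_vec path (mul_mat_vec_max_cols == 8);
+    // 512 still covers the full matrix-matrix path.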
+ for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1, 1}, {1, 1}));