#if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
- buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
- buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
- buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
- buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
- buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
- buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
+ A_TYPE32 aa = A_TYPE32(data_a[idx]);
+ buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w);
+ buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x);
+ buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y);
+ buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z);
+ buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w);
#elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
- buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
- buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
- buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
- buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
+ A_TYPE32 aa = A_TYPE32(data_a[idx]);
+ buf_a[buf_idx ] = FLOAT_TYPE(aa.x);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y);
+ buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z);
+ buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
- buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
- buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
- buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
- buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
- buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
- buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
- buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
+#if defined(DATA_B_BF16)
+ B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
+#else
+ B_TYPE32 bb = B_TYPE32(data_b[idx]);
+#endif
+ buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x);
+ buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y);
+ buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z);
+ buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w);
+ buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x);
+ buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y);
+ buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z);
+ buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
- buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
- buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
- buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
- buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
+#if defined(DATA_B_BF16)
+ B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
+#else
+ B_TYPE32 bb = B_TYPE32(data_b[idx]);
+#endif
+ buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x);
+ buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y);
+ buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z);
+ buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w);
#elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
};
// Shaders with f16 B_TYPE
- string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
if (!(coopmat || coopmat2))
#endif
{
- string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
// don't generate f32 variants for coopmat2
if (!coopmat2) {
- string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
- string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
- string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+ string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)