]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Use larger loads in scalar/coopmat1 matmul (llama/15729)
authorJeff Bolz <redacted>
Sun, 7 Sep 2025 16:53:07 +0000 (11:53 -0500)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
I think glslang will translate an access like x[i][1].z to
OpAccessChain ... x, i, 1, 2
OpLoad float16_t ...

rather than loading all of x[i] in a single OpLoad. Change the
code to explicitly load the vector/matrix.

src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/types.comp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 7e10e99e9e8771b3ec6f56f309baef1dbb6a2e67..f6a7761ffa03e05b0b3df459757b47085185ecf7 100644 (file)
@@ -315,21 +315,23 @@ void main() {
 #if LOAD_VEC_A == 8
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
-            buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
-            buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
-            buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
-            buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
+            A_TYPE32 aa = A_TYPE32(data_a[idx]);
+            buf_a[buf_idx    ] = FLOAT_TYPE(aa[0].x);
+            buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y);
+            buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z);
+            buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w);
+            buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x);
+            buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y);
+            buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z);
+            buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w);
 #elif LOAD_VEC_A == 4
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
+            A_TYPE32 aa = A_TYPE32(data_a[idx]);
+            buf_a[buf_idx    ] = FLOAT_TYPE(aa.x);
+            buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y);
+            buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z);
+            buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w);
 #else
             if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
                 buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
@@ -808,14 +810,19 @@ void main() {
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
             const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
-            buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
-            buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
-            buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
-            buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
+#if defined(DATA_B_BF16)
+            B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
+#else
+            B_TYPE32 bb = B_TYPE32(data_b[idx]);
+#endif
+            buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x);
+            buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y);
+            buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z);
+            buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w);
+            buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x);
+            buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y);
+            buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z);
+            buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w);
 #elif LOAD_VEC_B == 4
 #ifdef MUL_MAT_ID
             const u16vec2 row_idx = row_ids[loadc_b + l];
@@ -824,10 +831,15 @@ void main() {
             const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
 #endif
             const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-            buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
-            buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
-            buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
-            buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
+#if defined(DATA_B_BF16)
+            B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
+#else
+            B_TYPE32 bb = B_TYPE32(data_b[idx]);
+#endif
+            buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x);
+            buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y);
+            buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z);
+            buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w);
 #elif !MUL_MAT_ID
             if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
index 408722c87885aac17d57c6e662fc7a07c046b961..c2acc803f68e91fd23d21ed0ae0ca964615373cb 100644 (file)
 
 #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
 #define A_TYPE float
+#define A_TYPE32 float
 #elif LOAD_VEC_A == 4
 #define A_TYPE vec4
+#define A_TYPE32 vec4
 #elif LOAD_VEC_A == 8
 #define A_TYPE mat2x4
+#define A_TYPE32 mat2x4
 #endif
 #endif
 
 
 #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
 #define A_TYPE float16_t
+#define A_TYPE32 float
 #elif LOAD_VEC_A == 4
 #define A_TYPE f16vec4
+#define A_TYPE32 vec4
 #elif LOAD_VEC_A == 8
 #define A_TYPE f16mat2x4
+#define A_TYPE32 mat2x4
 #endif
 #endif
 
@@ -1424,6 +1430,11 @@ float bf16_to_fp32(uint32_t u)
     return uintBitsToFloat(u << 16);
 }
 
+vec4 bf16_to_fp32(uvec4 u)
+{
+    return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w));
+}
+
 float e8m0_to_fp32(uint8_t x) {
     uint32_t bits;
 
index 613498d0d50b703038adf0c0f1d7b7a192821168..93cdfd09a946f592a0148f419a4ac12c27dabf91 100644 (file)
@@ -364,11 +364,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     };
 
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"},                                                     {"B_TYPE", "float16_t"},                                          {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
 
-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16_aligned",     source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16",             source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"},                                                     {"B_TYPE", "float16_t"},                                          {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
 
     // bf16
     {
@@ -384,8 +384,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
         if (!(coopmat || coopmat2))
 #endif
         {
-            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"},   {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_bf16",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                      {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"},                          {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}),                   fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"},   {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                      {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"},                        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}),                   fp16, coopmat, coopmat2, f16acc);
         }
     }
 
@@ -408,13 +408,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float"},            {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32",         source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float"},                                              {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16",         source_name,  merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16",         source_name,  merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},                                          {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)