]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: vec dot matrix multiplication fix (#16151)
authorRuben Ortlam <redacted>
Mon, 22 Sep 2025 05:22:43 +0000 (07:22 +0200)
committerGitHub <redacted>
Mon, 22 Sep 2025 05:22:43 +0000 (07:22 +0200)
* vulkan: fix matrix multiplication index calculation for odd m/n and odd k in combination with batching

* add odd m/n + odd k test with batching

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 38a4d07d0300f624c4e0347178c19839c6bf73d3..3cb24412d577ee67b0d4a3fb7594b986654cf034 100644 (file)
 #include "types.comp"
 
 #ifndef LOAD_VEC_A
-#define LOAD_VEC_A 2
+#define LOAD_VEC_A 1
 #endif
 #ifndef LOAD_VEC_B
-#define LOAD_VEC_B 2
+#define LOAD_VEC_B 1
+#endif
+
+// Load 2 values at once without affecting index calculations through LOAD_VEC
+#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
+#define LOAD_VEC_BATCH_A 2
+#else
+#define LOAD_VEC_BATCH_A 1
+#endif
+#if !defined(ALIGNED)
+#define LOAD_VEC_BATCH_B 2
+#else
+#define LOAD_VEC_BATCH_B 1
 #endif
 
 #if !defined(TO_FLOAT_TYPE)
@@ -236,13 +248,13 @@ void main() {
     const uint warp_r = warp_i % (BM / WM);
     const uint warp_c = warp_i / (BM / WM);
 
-    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
-    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
-    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
-    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
+    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
+    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
+    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
 
-    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
-    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
+    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
+    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
 
 #ifdef MUL_MAT_ID
 #ifdef MUL_MAT_ID_USE_SUBGROUPS
index 69d0e64c35d9ff46b690a421136f24b1f03703b2..0ebfbd6462c8b76897ddc59b597cb6fdc12f22b1 100644 (file)
@@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
             buf_a[buf_idx    ] = aa.xy;
             buf_a[buf_idx + 1] = aa.zw;
-#else // LOAD_VEC_A == 2
-            const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
+#else // LOAD_VEC_BATCH_A == 2
+            const uint idx = pos_a + col * p.stride_a + row * 2;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (idx_m < p.M && block + row * 2 + 1 < end_k) {
                 buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
@@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
             FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
             buf_a[buf_idx    ] = aa.xy;
             buf_a[buf_idx + 1] = aa.zw;
-#else // LOAD_VEC_A == 2
-            const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
+#else // LOAD_VEC_BATCH_A == 2
+            const uint idx = pos_a + col * p.stride_a + row * 2;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (idx_m < p.M && block + row * 2 + 1 < end_k) {
                 buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
@@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
 #endif
             buf_b[buf_idx + 0] = bb.xy;
             buf_b[buf_idx + 1] = bb.zw;
-#else // LOAD_VEC_B == 2
-            const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
+#else // LOAD_VEC_BATCH_B == 2
+            const uint idx = pos_b + col * p.stride_b + row * 2;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (idx_n < p.N && block + row * 2 + 1 < end_k) {
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
@@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
 #endif
             buf_b[buf_idx + 0] = bb.xy;
             buf_b[buf_idx + 1] = bb.zw;
-#else // LOAD_VEC_B == 2
+#else // LOAD_VEC_BATCH_B == 2
             const uint row_i = ic * BN + col;
             const uint buf_idx = col * SHMEM_STRIDE + row;
             if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
                 const u16vec2 row_idx = row_ids[col];
-                const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
+                const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
                                                  TO_FLOAT_TYPE(data_b[idx + 1]));
             } else if (row_i < _ne1 && block + row * 2 < end_k) {
                 const u16vec2 row_idx = row_ids[col];
-                const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
+                const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
             } else {
                 buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
index 74a4794d34f9fc81fb9c3b8a24a2257e1c9a8bf8..8e2507ad8eec5cdd17e38a6ddc937dcea04ed0f9 100644 (file)
@@ -454,7 +454,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);
         // For unaligned, load one at a time for f32/f16, or two at a time for quants
-        std::string load_vec_a_unaligned = coopmat2 ? "1" : (tname == "f32" || tname == "f16" || tname == "bf16") ? "2" : load_vec_quant;
+        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
         // For aligned matmul loads
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
 
index 507b691dc96e204a4975e47a9f7a029724f77675..ef6594ea5902b22c6794d3db1026e69ad09d2814 100644 (file)
@@ -6231,6 +6231,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1,  1}, {1, 1}, {0, 1, 2, 3}, true, 3));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
 
     for (auto bs2 : {1,3}) {
         for (auto bs : {1,2,4,8}) {