]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Increase BK to 32; use BK/4 for non-CM mul_mm.comp (llama/16636)
authorSavicStefan <redacted>
Sat, 8 Nov 2025 08:28:22 +0000 (09:28 +0100)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
Signed-off-by: Stefan Savic <redacted>
Co-authored-by: Stefan Savic <redacted>
src/ggml-vulkan/vulkan-shaders/mul_mm.comp

index d260969f07e882137def1b1e9a227849f197d88e..5c5251da39bd1420724672534964196aeda8f976 100644 (file)
@@ -100,7 +100,6 @@ layout (push_constant) uniform parameter
 layout (constant_id = 0) const uint BLOCK_SIZE = 64;
 layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
-layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
 layout (constant_id = 4) const uint WM = 32;
 layout (constant_id = 5) const uint WN = 32;
 layout (constant_id = 6) const uint WMITER = 2;
@@ -109,6 +108,14 @@ layout (constant_id = 8) const uint TN = 2;
 layout (constant_id = 9) const uint TK = 1;  // Only needed for coopmat
 layout (constant_id = 10) const uint WARP = 32;
 
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+#define BK 32
+#define BK_STEP 4
+#else
+layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
+#define BK_STEP 2
+#endif
+
 #ifdef COOPMAT
 #define SHMEM_STRIDE (BK / 2 + 4)
 #else
@@ -244,8 +251,13 @@ void main() {
     }
 #else
     ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+    FLOAT_TYPE_VEC4 cache_a[WMITER * TM];
+    FLOAT_TYPE_VEC4 cache_b;
+#else
     FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
     FLOAT_TYPE_VEC2 cache_b;
+#endif
 
     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
         sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
@@ -283,24 +295,41 @@ void main() {
             }
         }
 #else
-        [[unroll]] for (uint i = 0; i < BK / 2; i++) {
+        [[unroll]] for (uint i = 0; i < BK / BK_STEP; i++) {
             // Load from shared into cache
             [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                 [[unroll]] for (uint j = 0; j < TM; j++) {
+                #if defined(DATA_A_F32) || defined(DATA_A_F16)
+                    cache_a[wsir * TM + j].xy = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i    ];
+                    cache_a[wsir * TM + j].zw = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + 2 * i + 1];
+                #else
                     cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
+                #endif
                 }
             }
 
             [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                 [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                #if defined(DATA_A_F32) || defined(DATA_A_F16)
+                    cache_b.xy = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i    ];
+                    cache_b.zw = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + 2 * i + 1];
+                #else
                     cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
+                #endif
 
                     [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                         [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
                             // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
                             const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
+                        #if defined(DATA_A_F32) || defined(DATA_A_F16)
+                            sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y),
+                                               fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x))));
+                            sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y),
+                                               fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y))));
+                        #else
                             sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
                             sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
+                        #endif
                         }
                     }
                 }