]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Add ACC_TYPE_VEC2 implementation (llama/16203)
authorSavicStefan <redacted>
Tue, 14 Oct 2025 17:18:05 +0000 (19:18 +0200)
committerGeorgi Gerganov <redacted>
Tue, 14 Oct 2025 19:07:44 +0000 (22:07 +0300)
Signed-off-by: Stefan Savic <redacted>
Co-authored-by: Stefan Savic <redacted>
src/ggml-vulkan/vulkan-shaders/mul_mm.comp

index 85400ac5fc343d8a91966a4f41dc0384ff52fcf2..a20788c4b51e34505dbdcc0a09e9d1cb4d435c6a 100644 (file)
@@ -313,12 +313,12 @@ void main() {
         sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
     }
 #else
-    ACC_TYPE sums[WMITER * TM * WNITER * TN];
+    ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
     FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
-    FLOAT_TYPE_VEC2 cache_b[TN];
+    FLOAT_TYPE_VEC2 cache_b;
 
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = ACC_TYPE(0.0f);
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+        sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
     }
 #endif
 
@@ -360,20 +360,22 @@ void main() {
                     cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                 }
             }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
-                }
 
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
+            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                    cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
+
+                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                        [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                            // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
+                            const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
+                            sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
+                            sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
                         }
                     }
                 }
             }
+
         }
 #endif
 
@@ -388,8 +390,9 @@ void main() {
         }
     }
 #else
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+        sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
+        sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
     }
 #endif
 #endif
@@ -463,14 +466,21 @@ void main() {
 
                 const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID
-                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                    const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
 #ifdef MUL_MAT_ID
-                    if (dr_warp + cr < p.M) {
-                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + 2 * cr < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+                    }
+                    if (dr_warp + 2 * cr + 1 < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                     }
 #else
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
+                    }
+                    if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
+                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
                     }
 #endif // MUL_MAT_ID
                 }