]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : improve clarity (minor) (llama/10171)
authorGeorgi Gerganov <redacted>
Fri, 8 Nov 2024 16:37:41 +0000 (18:37 +0200)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
ggml/src/ggml-metal.metal

index edce741088f05d1ef4875940021cc0b8f0eaad29..89f12724d3095eab21e2890e412bd71730e1741a 100644 (file)
@@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
     const short D4  = D/4;
     const short D16 = D/16;
     const short NW  = N_SIMDWIDTH;
-    const short NW4 = NW/4;
+    const short N = NW/4;
     const short SH  = 2*C; // shared memory per simdgroup
 
     const short T = D + nsg*SH; // shared memory size per query in (half)
@@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    o4x4_t lo[D16/NW4];
+    o4x4_t lo[D16/NL];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
     }
 
     // zero out lo
-    for (short i = 0; i < D16/NW4; i += NW4) {
+    for (short i = 0; i < D16/NL; ++i) {
         lo[i] = (o4x4_t) 0.0f;
     }
 
@@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
         half M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
-        const short tx = tiisg%8;
-        const short ty = tiisg/8;
+        const short tx = tiisg%NL;
+        const short ty = tiisg/NL;
 
         // broadcast kv
         //const short rk2 = ne02/ne12;
@@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
         const short ikv3 = iq3/(ne03/ne_12_3);
 
         // load the queries from shared memory into local memory
-        q4x4_t mq[D16/NW4];
+        q4x4_t mq[D16/NL];
 
-        for (short ii = 0; ii < D16; ii += NW4) {
-            mq[ii/NW4] = sq4x4[ii + tx];
+        for (short ii = 0; ii < D16; ii += NL) {
+            mq[ii/NL] = sq4x4[ii + tx];
         }
 
         const bool has_mask = mask != q;
@@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
                     device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
 #pragma unroll
-                    for (short ii = 0; ii < D16; ii += NW4) {
+                    for (short ii = 0; ii < D16; ii += NL) {
                         const short i = ii + tx;
 
                         k4x4_t mk;
                         deq_k(pk + i/nl_k, i%nl_k, mk);
 
                         mqk +=
-                            dot(mq[ii/NW4][0], mk[0]) +
-                            dot(mq[ii/NW4][1], mk[1]) +
-                            dot(mq[ii/NW4][2], mk[2]) +
-                            dot(mq[ii/NW4][3], mk[3]);
+                            dot(mq[ii/NL][0], mk[0]) +
+                            dot(mq[ii/NL][1], mk[1]) +
+                            dot(mq[ii/NL][2], mk[2]) +
+                            dot(mq[ii/NL][3], mk[3]);
                     }
 
                     // simdgroup reduce
@@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(
 
                 // O = diag(ms)*O
 #pragma unroll
-                for (short ii = 0; ii < D16; ii += NW4) {
-                    lo[ii/NW4] *= ms;
+                for (short ii = 0; ii < D16; ii += NL) {
+                    lo[ii/NL] *= ms;
                 }
             }
 
@@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
                     const s4x4_t ms(ss[4*cc + ty]);
 
 #pragma unroll
-                    for (short ii = 0; ii < D16; ii += NW4) {
+                    for (short ii = 0; ii < D16; ii += NL) {
                         const short i = ii + tx;
 
                         v4x4_t mv;
                         deq_v(pv4 + i/nl_v, i%nl_v, mv);
 
-                        lo[ii/NW4] += mv*ms;
+                        lo[ii/NL] += mv*ms;
                     }
                 }
             }
@@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
     // [ 5, 13, 21, 29] -> [ 5]
     // [ 6, 14, 22, 30] -> [ 6]
     // [ 7, 15, 23, 31] -> [ 7]
-    for (short ii = 0; ii < D16; ii += NW4) {
-        lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
-        lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0],  8);
-
-        lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
-        lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1],  8);
-
-        lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
-        lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2],  8);
-
-        lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
-        lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3],  8);
+    for (short ii = 0; ii < D16; ii += NL) {
+        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
+        lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  8);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  4);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  2);
+      //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0],  1);
+
+        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
+        lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  8);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  4);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  2);
+      //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1],  1);
+
+        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
+        lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  8);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  4);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  2);
+      //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2],  1);
+
+        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
+        lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  8);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  4);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  2);
+      //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3],  1);
     }
 
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // store results to shared memory
-    for (short i = tiisg; i < D16; i += NW4) {
-        sr4x4[i] = lo[i/NW4];
+    for (short i = tiisg; i < D16; i += NL) {
+        sr4x4[i] = lo[i/NL];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);