]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : use less stack memory in FA kernel (llama/14088)
authorGeorgi Gerganov <redacted>
Mon, 9 Jun 2025 20:05:02 +0000 (23:05 +0300)
committerGeorgi Gerganov <redacted>
Tue, 10 Jun 2025 09:40:33 +0000 (12:40 +0300)
* metal : use less stack memory in FA kernel

ggml-ci

* cont : fix BF16 variant

ggml/src/ggml-metal/ggml-metal.metal

index 58763e39e835312729ec308e6c868217f1f7e3ba..5d7760217f82644602fdb635f52943a8b9f89d56 100644 (file)
@@ -3333,8 +3333,6 @@ kernel void kernel_flash_attn_ext(
 
     threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data
     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +                0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +                0*DK); // same as above but in o4_t
     threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
@@ -3548,20 +3546,20 @@ kernel void kernel_flash_attn_ext(
 
             // O = diag(ms)*O
             {
-                s8x8_t mm;
-                simdgroup_load(mm, ss + 2*C, TS, 0, false);
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
 
                 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-                    simdgroup_multiply(lo[i], mm, lo[i]);
+                    simdgroup_multiply(lo[i], ms, lo[i]);
                 }
             }
 
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    s8x8_t ms;
-                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+                    s8x8_t vs;
+                    simdgroup_load(vs, ss + 8*cc, TS, 0, false);
 
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
@@ -3572,7 +3570,7 @@ kernel void kernel_flash_attn_ext(
                             v8x8_t mv;
                             simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
 
-                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+                            simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
                         }
                     } else {
                         for (short ii = 0; ii < DV16; ii += 4) {
@@ -3593,10 +3591,10 @@ kernel void kernel_flash_attn_ext(
                                     v8x8_t mv;
 
                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
 
                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
                                 if (ii + tx < DV16) {
@@ -3611,10 +3609,10 @@ kernel void kernel_flash_attn_ext(
                                     v8x8_t mv;
 
                                     simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
 
                                     simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
-                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
                                 }
                             }
                         }
@@ -3624,83 +3622,80 @@ kernel void kernel_flash_attn_ext(
         }
 
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
-        for (short j = 0; j < Q; ++j) {
-            if (tiisg == 0) {
-                ss[j*TS + 0] = S[j];
-                ss[j*TS + 1] = M[j];
-            }
+        for (short j = tiisg; j < Q; j += NW) {
+            ss[j*TS + 0] = S[j];
+            ss[j*TS + 1] = M[j];
         }
     }
 
-    // reduce the warps sequentially
-    for (ushort sg = 1; sg < nsg; ++sg) {
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // each simdgroup stores its output to shared memory, reusing sq
-        if (sgitg == sg) {
-            for (short i = 0; i < DV8; ++i) {
-                simdgroup_store(lo[i], so + i*8, DV, 0, false);
-            }
+    threadgroup float  * so  = (threadgroup float  *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+    threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
+
+    // store result to shared memory in F32
+    if (sgitg == 0) {
+        for (short i = 0; i < DV8; ++i) {
+            //simdgroup_store(lo[i], so + i*8, DV, 0, false);
+            simdgroup_float8x8 t(1.0f);
+            simdgroup_multiply(t, lo[i], t);
+            simdgroup_store(t, so + i*8, DV, 0, false);
         }
+    }
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // the first simdgroup accumulates the results from the other simdgroups
-        if (sgitg == 0) {
-            for (short j = 0; j < Q; ++j) {
-                const float S0 = ss[j*TS +         0];
-                const float S1 = ss[j*TS + sg*SH + 0];
+    // reduce the warps sequentially
+    for (ushort sg = 1; sg < nsg; ++sg) {
+        if (sgitg == sg) {
+            for (short j = tiisg; j < Q; j += NW) {
+                const float S0 = ss[j*TS - 1*SH + 0];
+                const float S1 = ss[j*TS        + 0];
 
-                const float M0 = ss[j*TS +         1];
-                const float M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS - 1*SH + 1];
+                const float M1 = ss[j*TS        + 1];
 
                 const float M = max(M0, M1);
 
-                const float ms0 = exp(M0 - M);
-                const float ms1 = exp(M1 - M);
+                float ms0 = exp(M0 - M);
+                float ms1 = exp(M1 - M);
 
                 const float S = S0*ms0 + S1*ms1;
 
-                if (tiisg == 0) {
-                    ss[j*TS + 0] = S;
-                    ss[j*TS + 1] = M;
+                ss[j*TS + 0] = S;
+                ss[j*TS + 1] = M;
 
-                    ss[j*TS + 2*C + j        ] = ms0;
-                    ss[j*TS + 2*C + j + sg*SH] = ms1;
-                }
+                ss[j*TS + 2*C + j - 1*SH] = ms0;
+                ss[j*TS + 2*C + j       ] = ms1;
             }
 
+            //simdgroup_barrier(mem_flags::mem_threadgroup);
+
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             {
                 s8x8_t ms0;
                 s8x8_t ms1;
 
-                simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
-                simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+                simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
+                simdgroup_load(ms1, ss + 2*C,        TS, 0, false);
 
                 #pragma unroll(DV8)
                 for (short i = 0; i < DV8; ++i) {
-                    o8x8_t t;
+                    simdgroup_float8x8 t;
 
                     simdgroup_load    (t, so + i*8, DV, 0, false);
-                    simdgroup_multiply(t, ms1, t);
+                    simdgroup_multiply(t, ms0, t);
 
-                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+                    simdgroup_multiply_accumulate(t, ms1, lo[i], t);
+                    simdgroup_store(t, so + i*8, DV, 0, false);
                 }
             }
         }
-    }
 
-    // store result to shared memory (reuse sq)
-    if (sgitg == 0) {
-        for (short i = 0; i < DV8; ++i) {
-            simdgroup_store(lo[i], so + i*8, DV, 0, false);
-        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
 
     // final rescale with 1/S and store to global memory
     for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
@@ -3723,8 +3718,8 @@ kernel void kernel_flash_attn_ext(
     half,   half4x4,   simdgroup_half8x8,  \
     float,             simdgroup_float8x8, \
     float,             simdgroup_float8x8, \
-    float,  float4,    simdgroup_float8x8
-    //half,   half4,     simdgroup_half8x8
+    half,   half4,     simdgroup_half8x8
+    //float,  float4,    simdgroup_float8x8
 
 #define FA_TYPES_BF \
     bfloat, bfloat4,   simdgroup_bfloat8x8, \
@@ -3732,8 +3727,8 @@ kernel void kernel_flash_attn_ext(
     bfloat, bfloat4x4, simdgroup_bfloat8x8, \
     float,             simdgroup_float8x8,  \
     float,             simdgroup_float8x8,  \
-    float,  float4,    simdgroup_float8x8
-    //half,   half4,     simdgroup_half8x8
+    half,   half4,     simdgroup_half8x8
+    //float,  float4,    simdgroup_float8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;