]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : use F32 prec in FA kernels (llama/12688)
authorGeorgi Gerganov <redacted>
Tue, 1 Apr 2025 11:57:19 +0000 (14:57 +0300)
committerGeorgi Gerganov <redacted>
Wed, 2 Apr 2025 12:51:57 +0000 (15:51 +0300)
* metal : use F32 prec in FA kernels

ggml-ci

* cont : fix FA vec kernel

ggml-ci

ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index 3942013f4c90ae12ae6f8907762e43d3d7788609..456e1fd994c4041eaec76fa68e62661023860432 100644 (file)
@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
                     // ne00*(nsg)
                     // each simdgroup has a full f16 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
                     while (true) {
index 80d0765b4fc0e7d506a7e9824177dab1529f41df..b08666e27991f52fef2a17f12d91492bce644b7f 100644 (file)
@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S[Q] = { [0 ... Q-1] = 0.0f };
-        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
+        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
 
         const bool has_mask = mask != q;
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
 
             if (has_mask) {
                 // used to detect blocks full of -INF
-                half smax = -INFINITY;
+                float smax = -INFINITY;
 
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
                     device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
 
-                    const half m = pm[ic + tiisg];
+                    const float m = pm[ic + tiisg];
 
                     ss[j*TS + C + tiisg] = m;
                     smax = max(smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
             // online softmax
             {
                 for (ushort j = 0; j < Q; ++j) {
-                    const half m = M[j];
+                    const float m = M[j];
 
                     // scale and apply the logitcap / mask
-                    half s = ss[j*TS + tiisg]*args.scale;
+                    float s = ss[j*TS + tiisg]*args.scale;
 
                     if (args.logit_softcap != 0.0f) {
                         s = args.logit_softcap*precise::tanh(s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
 
                     M[j] = simd_max(max(M[j], s));
 
-                    const half ms = exp(m - M[j]);
-                    const half vs = exp(s - M[j]);
+                    const float ms = exp(m - M[j]);
+                    const float vs = exp(s - M[j]);
 
                     S[j] = S[j]*ms + simd_sum(vs);
 
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
 
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
-        half S = { 0.0f };
-        half M = { -__FLT16_MAX__/2 };
+        float S = { 0.0f };
+        float M = { -__FLT16_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*TS +         0];
-                const half S1 = ss[j*TS + sg*SH + 0];
+                const float S0 = ss[j*TS +         0];
+                const float S1 = ss[j*TS + sg*SH + 0];
 
-                const half M0 = ss[j*TS +         1];
-                const half M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS +         1];
+                const float M1 = ss[j*TS + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const half ms0 = exp(M0 - M);
-                const half ms1 = exp(M1 - M);
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
     constexpr short DV4 = DV/4;
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
-    constexpr short SH  = 2*C;   // shared memory per simdgroup
+    constexpr short SH  = 4*C;   // shared memory per simdgroup
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + sgitg*SH     + Q*DK); // scratch buffer for attention
-    threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH     + Q*DK); // same as above but in s4_t
-    threadgroup half * sm  = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV     + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S = 0.0f;
-        half M = -__FLT16_MAX__/2;
+        float S = 0.0f;
+        float M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31);
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
 
             // online softmax
             {
-                const half m = M;
-                const half s = ss[tiisg];
+                const float m = M;
+                const float s = ss[tiisg];
 
                 M = simd_max(max(M, s));
 
-                const half ms = exp(m - M);
-                const half vs = exp(s - M);
+                const float ms = exp(m - M);
+                const float vs = exp(s - M);
 
                 S = S*ms + simd_sum(vs);
 
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
                         v4_t mv;
                         deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
-                        lo[ii/NL] += mv*ms;
+                        lo[ii/NL] += o4_t(float4(mv)*float4(ms));
                     }
                 }
             }
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const half S0 = ss[       0];
-            const half S1 = ss[r*SH + 0];
+            const float S0 = ss[           0];
+            const float S1 = ss[r*(SH/2) + 0];
 
-            const half M0 = ss[       1];
-            const half M1 = ss[r*SH + 1];
+            const float M0 = ss[           1];
+            const float M1 = ss[r*(SH/2) + 1];
 
-            const half M = max(M0, M1);
+            const float M = max(M0, M1);
 
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
 
-            const half S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
-           half4, \
-           half4, \
-           half4, \
-    float,        \
-    half,  half4, \
+           half4,  \
+           half4,  \
+           half4,  \
+    float,         \
+    float, float4, \
            half4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;