]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : use F32 accumulators in FA kernels (llama/13975)
authorGeorgi Gerganov <redacted>
Mon, 2 Jun 2025 18:33:40 +0000 (21:33 +0300)
committerGeorgi Gerganov <redacted>
Tue, 10 Jun 2025 09:40:33 +0000 (12:40 +0300)
ggml-ci

ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index f78e7eee553b65dd90fa21b52f48f516192157e8..bc93bc633a49b224d38d3485053c205ba2bbdec3 100644 (file)
@@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node(
                     GGML_ASSERT(nqptg  % 8  == 0);
                     GGML_ASSERT(ncpsg  % 32 == 0);
 
+                    const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
                     // 2*(2*ncpsg + nqptg)*(nsg)
                     // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
                     //
@@ -4773,7 +4775,7 @@ static bool ggml_metal_encode_node(
                     // the shared memory needed for the simdgroups to load the KV cache
                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
                     // and store the soft_max values and the mask
                     //
                     // ne00*(nsg)
-                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
+                    // each simdgroup has a full f32 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
                     while (true) {
index 59899550ed38cc64003c4a540ff1ba0033cb530c..58763e39e835312729ec308e6c868217f1f7e3ba 100644 (file)
@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
-    const short T  = DK + 2*TS; // shared memory size per query in (half)
+    const short TS = nsg*SH;      // shared memory size per query in (s_t == float)
+    const short T  = 2*DK + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*DK); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*DK); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*DK); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*DK); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +                0*DK); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +                0*DK); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +                0*DK); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +                0*DK); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
             if (iq1 + j < args.ne01) {
                 sq4[j*DK4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*DK4 + i] = (q4_t) 0.0f;
+                sq4[j*DK4 + i] = 0;
             }
         }
     }
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
 
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -__FLT_MAX__/2 };
-
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
                 const float M0 = ss[j*TS +         1];
                 const float M1 = ss[j*TS + sg*SH + 1];
 
-                M = max(M0, M1);
+                const float M = max(M0, M1);
 
                 const float ms0 = exp(M0 - M);
                 const float ms1 = exp(M1 - M);
 
-                S = S0*ms0 + S1*ms1;
+                const float S = S0*ms0 + S1*ms1;
 
                 if (tiisg == 0) {
                     ss[j*TS + 0] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
         }
     }
 
-    device float4 * dst4 = (device float4 *) dst;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
 
     // final rescale with 1/S and store to global memory
-    if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-            const float S = ss[j*TS + 0];
+    for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+        const float S = 1.0f/sf[j*TS + 0];
 
-            for (short i = tiisg; i < DV4; i += NW) {
-                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-            }
+        device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
+        for (short i = tiisg; i < DV4; i += NW) {
+            dst4[i] = (float4) so4[j*DV4 + i]*S;
         }
     }
 }
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
 //       template to be able to explore different combinations
 //
 #define FA_TYPES \
-    half,  half4,   simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    float,          simdgroup_float8x8, \
-    float,          simdgroup_float8x8, \
-    half,  half4,   simdgroup_half8x8
+    float,  float4,    simdgroup_float8x8, \
+    half,   half4x4,   simdgroup_half8x8,  \
+    half,   half4x4,   simdgroup_half8x8,  \
+    float,             simdgroup_float8x8, \
+    float,             simdgroup_float8x8, \
+    float,  float4,    simdgroup_float8x8
+    //half,   half4,     simdgroup_half8x8
+
+#define FA_TYPES_BF \
+    bfloat, bfloat4,   simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+    float,             simdgroup_float8x8,  \
+    float,             simdgroup_float8x8,  \
+    float,  float4,    simdgroup_float8x8
+    //half,   half4,     simdgroup_half8x8
 
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,    1, dequantize_f16,  576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 64,  64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 80,  80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 96,  96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4,  1, dequantize_bf16, bfloat4x4,  1, dequantize_bf16, 576, 512>;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64,  64>;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
+#undef FA_TYPES_BF
 
 template<
     typename q4_t,  // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                  0*DK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                  0*DK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH       + Q*DK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH       + Q*DK); // same as above but in s4_t
-    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*DK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*DK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*DK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*DK); // same as above but in s4_t
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 +   sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+    threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*DV       + Q*T);  // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
            half4,  \
     float,         \
     float, float4, \
-           half4
+           float4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;