]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : minor cleanup (#19251)
authorGeorgi Gerganov <redacted>
Tue, 3 Feb 2026 11:43:29 +0000 (13:43 +0200)
committerGitHub <redacted>
Tue, 3 Feb 2026 11:43:29 +0000 (13:43 +0200)
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 59d88b01a5581837d0b91abfd814f7cba40a8461..e074f2ef3db157c75bc9ca5b1c36d3e05db96178 100644 (file)
 #define FC_COUNT_EQUAL                 1000
 
 // op-specific constants
-#define OP_FLASH_ATTN_EXT_NQPTG 8
+#define OP_FLASH_ATTN_EXT_NQPSG 8
 #define OP_FLASH_ATTN_EXT_NCPSG 64
 
-#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
+#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
 
 // kernel argument structs
index 7f4cfbba226c6ebbc314ce7356259d7f526adde6..f97c4435dec9f2d818702bdd01ab9bb27fd010b0 100644 (file)
@@ -2295,7 +2295,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
     //    return res;
     //}
 
-    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+    const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
     const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
 
     const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
@@ -2411,7 +2411,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
-        const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+        const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
         const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
 
         GGML_ASSERT(nqptg <= 32);
@@ -2578,9 +2578,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 #undef FATTN_SMEM
     } else {
         // half4x4 kernel
-        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+        const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
         const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
-        const int nkpsg = 1*ncpsg;
+        const int nhptg = 1;                           // heads per threadgroup
 
         GGML_ASSERT(nqptg <= 32);
         GGML_ASSERT(nqptg  % 1  == 0);
@@ -2632,6 +2632,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             ggml_metal_op_concurrency_reset(ctx);
         }
 
+        // note: for simplicity assume the K is larger or equal than V
+        GGML_ASSERT(ne10 >= ne20);
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2639,28 +2642,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         // ne20*(nsg)
         // each simdgroup has a full f32 head vector in shared mem to accumulate results
         //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
-
-        int64_t nsgmax = 2;
-        while (true) {
-            const size_t smem = FATTN_SMEM(nsgmax);
-            // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
-            if (smem > props_dev->max_theadgroup_memory_size/2) {
-                break;
-            }
-            nsgmax *= 2;
-        }
-        nsgmax /= 2;
-
-        // simdgroups per threadgroup (a.k.a. warps)
-        //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
-        const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
+#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
 
         int64_t nsg = 1;
-        while (nsg <= nsgt) {
-            nsg *= 2;
-        }
-        nsg /= 2;
 
         // workgroups
         // each workgroup handles nsg*nkpsg cache values
@@ -2673,7 +2657,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         } else {
             nwg = 32;
             nsg = 1;
-            while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+            while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
                 nsg *= 2;
             }
         }
@@ -2739,7 +2723,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
         } else {
             // sanity checks
             assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
@@ -2752,7 +2736,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
-            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+            ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
 
             // sync the 2 kernels
             ggml_metal_op_concurrency_reset(ctx);
index 17e358d1a8d824963b5976d111eb15807e2fd1a5..3259213fd6160a48ebe66a1ef2265518961e77b8 100644 (file)
@@ -5931,7 +5931,7 @@ template<
     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
     short DK,         // K head size
     short DV,         // V head size
-    short Q  = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
+    short Q  = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
     short C  = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         constant ggml_metal_kargs_flash_attn_ext & args,
@@ -6141,11 +6141,10 @@ template<
     void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
     short DK,       // K head size
     short DV,       // V head size
-    short NE,       // head elements per thread
-    short Q,        // queries per threadgroup
-    short C,        // cache items per threadgroup
-    short NSG>      // number of simd groups
-void kernel_flash_attn_ext_vec_impl(
+    short NE = 4,   // head elements per thread
+    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPSG,  // queries per threadgroup
+    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
         constant ggml_metal_kargs_flash_attn_ext_vec & args,
         device const char * q,
         device const char * k,
@@ -6162,6 +6161,7 @@ void kernel_flash_attn_ext_vec_impl(
     static_assert(DV % 32 == 0, "DV must be divisible by 32");
 
 #define NWG  (FC_flash_attn_ext_vec_nwg)
+#define NSG  (FC_flash_attn_ext_vec_nsg)
 
 #define NS10 (FC_flash_attn_ext_vec_ns10)
 #define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -6190,12 +6190,12 @@ void kernel_flash_attn_ext_vec_impl(
 
     const short T = PK + NSG*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                    0*PK); // holds the query data
-    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                    0*PK); // same as above but in q4_t
-    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + Q*PK); // scratch buffer for attention
-    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + Q*PK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
-    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + Q*T);  // scratch buffer for the results
+  //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 +                      0*PK); // holds the query data
+    threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 +                      0*PK); // same as above but in q4_t
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 +   sgitg*SH       + NSG*PK); // scratch buffer for attention
+    threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 +   sgitg*SH       + NSG*PK); // same as above but in s4_t
+    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 +   sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
+    threadgroup o4_t  * so4 = (threadgroup o4_t  *) (shmem_f16 + 2*sgitg*PV       + NSG*PK + NSG*SH); // scratch buffer for the results
 
     // store the result for all queries in shared memory (the O matrix from the paper)
     so4 += tiisg;
@@ -6213,11 +6213,13 @@ void kernel_flash_attn_ext_vec_impl(
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q);
 
-    for (short i = tiisg; i < PK4; i += NW) {
-        if (iq1 < args.ne01 && i < DK4) {
-            sq4[i] = (q4_t) q4[i];
-        } else {
-            sq4[i] = (q4_t) 0.0f;
+    if (iq1 < args.ne01) {
+        for (short i = tiisg; i < PK4; i += NW) {
+            if (i < DK4) {
+                sq4[i] = (q4_t) q4[i];
+            } else {
+                sq4[i] = (q4_t) 0.0f;
+            }
         }
     }
 
@@ -6295,7 +6297,7 @@ void kernel_flash_attn_ext_vec_impl(
             }
 
             // skip -INF blocks
-            if (simd_max(sm[tiisg]) == -INFINITY) {
+            if (simd_max(sm[tiisg]) <= -MAXHALF) {
                 continue;
             }
 
@@ -6569,57 +6571,11 @@ void kernel_flash_attn_ext_vec_impl(
     }
 
 #undef NWG
+#undef NSG
 #undef NS10
 #undef NS20
 }
 
-template<
-    typename q4_t,  // query types in shared memory
-    typename k4_t,  // key types in shared memory
-    typename v4_t,  // value types in shared memory
-    typename qk_t,  // Q*K types
-    typename s_t,   // soft-max types
-    typename s4_t,
-    typename o4_t,  // attention accumulation types
-    typename kd4_t, // key type in device memory
-    short nl_k,
-    void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
-    typename vd4_t, // value type in device memory
-    short nl_v,
-    void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
-    short DK,       // K head size
-    short DV,       // V head size
-    short NE = 4,   // head elements per thread
-    short Q  = OP_FLASH_ATTN_EXT_VEC_NQPTG,  // queries per threadgroup
-    short C  = OP_FLASH_ATTN_EXT_VEC_NCPSG>  // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
-        constant ggml_metal_kargs_flash_attn_ext_vec & args,
-        device const char * q,
-        device const char * k,
-        device const char * v,
-        device const char * mask,
-        device const char * sinks,
-        device const char * pad,
-        device       char * dst,
-        threadgroup  half * shmem_f16 [[threadgroup(0)]],
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort  tiisg[[thread_index_in_simdgroup]],
-        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
-#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
-    switch (FC_flash_attn_ext_vec_nsg) {
-      // note: disabled cases to reduce library load time
-        case 1:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  1>(FWD_ARGS); break;
-        case 2:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  2>(FWD_ARGS); break;
-        case 4:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  4>(FWD_ARGS); break;
-      //case 8:  kernel_flash_attn_ext_vec_impl<FWD_TMPL,  8>(FWD_ARGS); break;
-      //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
-      //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
-    }
-#undef FWD_TMPL
-#undef FWD_ARGS
-}
-
 // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
 //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //