]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : optimize FA vec for large sequences and BS <= 8 (#15566)
authorGeorgi Gerganov <redacted>
Tue, 26 Aug 2025 11:22:14 +0000 (14:22 +0300)
committerGitHub <redacted>
Tue, 26 Aug 2025 11:22:14 +0000 (14:22 +0300)
* metal : optmize FA vec for large heads and sequences

* metal : adjust small-batch mul mv kernels

ggml-ci

* batched-bench : fix total speed computation

ggml-ci

* cont : add comments

ggml-ci

ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
tools/batched-bench/batched-bench.cpp

index 82c1ac1da0b13b490027fb0000cc92830ddcb314..b9d36394485009f48990a9775a76bd86fccf4e6f 100644 (file)
@@ -249,6 +249,7 @@ typedef struct {
     uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
+    int32_t  ne3;
     float    scale;
     float    max_bias;
     float    m0;
@@ -257,6 +258,11 @@ typedef struct {
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
+typedef struct {
+    int32_t  nrows;
+    int32_t  ne20;
+} ggml_metal_kargs_flash_attn_ext_reduce;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne02;
index 7a05a98278f94699b68ef0d5f1471dea122770ac..1f93633d91f260e6a9821904b094a0c1a4f1227a 100644 (file)
@@ -291,6 +291,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -575,6 +579,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1324,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,                 mul_mv_q5_1_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,                 mul_mv_q8_0_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,                mul_mv_mxfp4_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,         mul_mv_ext_f32_f32_r1_2,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,         mul_mv_ext_f32_f32_r1_3,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,         mul_mv_ext_f32_f32_r1_4,         has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,         mul_mv_ext_f32_f32_r1_5,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,         mul_mv_ext_f16_f32_r1_2,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,         mul_mv_ext_f16_f32_r1_3,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,         mul_mv_ext_f16_f32_r1_4,         has_simdgroup_reduction);
@@ -1609,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,    flash_attn_ext_vec_q5_0_hk576_hv512,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,    flash_attn_ext_vec_q5_1_hk576_hv512,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,    flash_attn_ext_vec_q8_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,           flash_attn_ext_reduce,           has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                         set_f32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                         set_i32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
 
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                const int ne11_mm_min = 4;
+                const int ne11_mm_min = 8;
 
                 // first try to use small-batch mat-mv kernels
                 // these should be efficient for BS [2, ~8]
-                if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+                if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
                     (
                      (
                       (
-                       src0t == GGML_TYPE_F16  || // TODO: helper function
+                       src0t == GGML_TYPE_F32  || // TODO: helper function
+                       src0t == GGML_TYPE_F16  ||
                        src0t == GGML_TYPE_Q4_0 ||
                        src0t == GGML_TYPE_Q4_1 ||
                        src0t == GGML_TYPE_Q5_0 ||
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
                     //       values and there can be some tail effects when nsg is high. need to confirm this
                     //
                     const int nsg    = 2;                 // num simdgroups per threadgroup
-                    const int nxpsg  = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+                    // num threads along row per simdgroup
+                    int nxpsg = 0;
+                    if (ne00 % 256 == 0 && ne11 < 3) {
+                        nxpsg = 16;
+                    } else if (ne00 % 128 == 0) {
+                        nxpsg = 8;
+                    } else {
+                        nxpsg = 4;
+                    }
+
                     const int nypsg  = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
                     const int r0ptg  = nypsg*nsg;         // num src0 rows per threadgroup
                           int r1ptg  = 4;                 // num src1 rows per threadgroup
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
                     id<MTLComputePipelineState> pipeline = nil;
 
                     switch (src0->type) {
+                        case GGML_TYPE_F32:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
                         case GGML_TYPE_F16:
                             switch (r1ptg) {
                                 case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
                         case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
                         case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
                         case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
-                        case GGML_TYPE_MXFP4:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32   ].pipeline; break;
+                        case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32  ].pipeline; break;
                         case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
                         case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
                         case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
                     /*.nb33          =*/ nb33,
                     /*.ne1           =*/ ne1,
                     /*.ne2           =*/ ne2,
+                    /*.ne3           =*/ ne3,
                     /*.scale         =*/ scale,
                     /*.max_bias      =*/ max_bias,
                     /*.m0            =*/ m0,
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
                 } else {
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst       atIndex:6];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
 
                     while (true) {
                         const size_t smem = FATTN_SMEM(nsgmax);
-                        if (smem > device.maxThreadgroupMemoryLength) {
+                        if (smem > device.maxThreadgroupMemoryLength/2) {
                             break;
                         }
                         nsgmax *= 2;
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
 
                     const size_t smem = FATTN_SMEM(nsg);
 
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
                     //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                     GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                     [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
                 } else {
                     // half4x4 kernel
                     const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                     const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+                    const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
 
                     GGML_ASSERT(nqptg <= 32);
                     GGML_ASSERT(nqptg  % 1  == 0);
@@ -5561,15 +5593,17 @@ static int ggml_metal_encode_node(
                     // for each query, we load it as f16 in shared memory (ne00)
                     // and store the soft_max values and the mask
                     //
-                    // ne00*(nsg)
+                    // ne20*(nsg)
                     // each simdgroup has a full f32 head vector in shared mem to accumulate results
                     //
 #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
                     while (true) {
                         const size_t smem = FATTN_SMEM(nsgmax);
-                        if (smem > device.maxThreadgroupMemoryLength) {
+                        // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
+                        if (smem > device.maxThreadgroupMemoryLength/2) {
                             break;
                         }
                         nsgmax *= 2;
@@ -5577,7 +5611,7 @@ static int ggml_metal_encode_node(
                     nsgmax /= 2;
 
                     // simdgroups per threadgroup (a.k.a. warps)
-                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+                    const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
                     int64_t nsg = 1;
                     while (nsg <= nsgt) {
@@ -5585,13 +5619,74 @@ static int ggml_metal_encode_node(
                     }
                     nsg /= 2;
 
-                    const size_t smem = FATTN_SMEM(nsg);
+                    // workgroups
+                    // each workgroup handles nsg*nkpsg cache values
+                    uint16_t nwg = 1;
+                    if (4*nsg*nkpsg >= ne11) {
+                        const size_t smem = FATTN_SMEM(nsg);
 
-                    //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
-                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                        //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+                        GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+                        // using 1 workgroup -> write the result directly into dst
+                        [encoder setBuffer:id_dst offset:offs_dst      atIndex:6];
+                        [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                    } else {
+                        nwg = 32;
+                        nsg = MIN(4, nsg);
+
+                        const size_t smem = FATTN_SMEM(nsg);
+
+                        //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+                        GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+                        // sanity checks
+                        GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+                        GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+                        const int32_t nrows = ne1*ne2*ne3;
+
+                        // temp buffer for writing the results from each workgroup
+                        // - ne20: the size of the head vector
+                        // -  + 2: the S and M values for each intermediate result
+                        const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
+                        id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
+                        if (!h_tmp) {
+                            GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
+                            return 0;
+                        }
+
+                        //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
+                        //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
+
+                        [encoder setBuffer:h_tmp  offset:0             atIndex:6];
+                        [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+                        // reduce the results from the workgroups
+                        {
+                            ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+                                nrows,
+                                ne20,
+                            };
+
+                            id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+
+                            [encoder setComputePipelineState:pipeline0];
+                            [encoder setBytes:&args0   length:sizeof(args0) atIndex:0];
+                            [encoder setBuffer:h_tmp   offset:0             atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst      atIndex:2];
+
+                            //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
+                            [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+                        }
+                    }
 #undef FATTN_SMEM
-                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
         case GGML_OP_DUP:
index 7037c1aa0d30d741bfdaa86f9eaced5cea37927f..fa80d6e405978dd47a271890343c2a804574327b 100644 (file)
@@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
+template <typename type4>
+void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*src);
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
@@ -3015,7 +3020,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
 #pragma unroll(r1ptg)
             for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                 sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
-
             }
         }
 
@@ -3200,6 +3204,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
 typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
 typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)    mul_mv_ext_q4x4_f32_t;
 
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4,       4,  dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4,       4,  dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4,       4,  dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4,       4,  dequantize_f32_t4>;
+
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4,        4,  dequantize_f16_t4>;
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4,        4,  dequantize_f16_t4>;
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;
@@ -4786,14 +4795,16 @@ kernel void kernel_flash_attn_ext_vec(
         device const char * mask,
         device const char * sinks,
         device       char * dst,
+        constant uint16_t & nwg,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3   ntg[[threads_per_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
+    const short iwg = tgpig[2]%nwg;
 
-    const int iq3 = tgpig[2];
+    const int iq3 = tgpig[2]/nwg;
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0];
 
@@ -4872,7 +4883,7 @@ kernel void kernel_flash_attn_ext_vec(
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+        for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
             const int ic = ic0 + C*sgitg;
             if (ic >= args.ne11) {
                 break;
@@ -5002,7 +5013,7 @@ kernel void kernel_flash_attn_ext_vec(
             }
         }
 
-        if (sinks != q && sgitg == 0) {
+        if (sinks != q && sgitg == 0 && iwg == 0) {
             const float m = M;
             const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
 
@@ -5111,14 +5122,25 @@ kernel void kernel_flash_attn_ext_vec(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4 * dst4 = (device float4 *) dst;
-
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
-        const float S = ss[0];
+        const int64_t nrows = args.ne3*args.ne2*args.ne1;
+        const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
+
+        device float4 * dst4 = (device float4 *) dst;
+        device float  * dst1 = (device float  *) dst + nrows*DV*nwg; // the S and M are stored after the results
+
+        const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
 
+        // interleave the workgroup data
         for (short i = tiisg; i < DV4; i += NW) {
-            dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
+            dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+        }
+
+        // store S and M
+        if (nwg > 1 && tiisg == 0) {
+            dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
+            dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
         }
     }
 }
@@ -5218,6 +5240,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
 
 #undef FA_TYPES
 
+kernel void kernel_flash_attn_ext_reduce(
+        constant ggml_metal_kargs_flash_attn_ext_reduce & args,
+        device  const char * htmp,
+        device        char * dst,
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const uint64_t rid = tgpig;
+
+    const short nwg = 32;
+    const short iwg = tiisg;
+    const short DV  = args.ne20;
+    const short DV4 = DV/4;
+
+    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
+    device const float  * ss    = (device const float  *) htmp + (uint64_t)args.nrows*DV*nwg;
+    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;
+
+    float S = ss[rid*(2*nwg) + 2*iwg + 0];
+    float M = ss[rid*(2*nwg) + 2*iwg + 1];
+
+    const float m  = simd_max(M);
+    const float ms = exp(M - m);
+
+    S = 1.0f/simd_sum(S*ms);
+
+    for (int i = sgitg; i < DV4; i += nwg) {
+        const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
+
+        if (iwg == 0) {
+            dst4[i] = v*S;
+        }
+    }
+}
+
 template<typename T>
 kernel void kernel_set(
     constant ggml_metal_kargs_set & args,
index 93efad3280913d1a4b67d2b210d2fd748424da49..23d03039dcc036f0c28078cb5e6666d3b378c86f 100644 (file)
@@ -191,7 +191,7 @@ int main(int argc, char ** argv) {
 
                 const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
                 const float speed_tg = pl*tg / t_tg;
-                const float speed    = n_kv / t;
+                const float speed    = ((is_pp_shared ? pp : pl*pp) + pl*tg) / t;
 
                 if(params.batched_bench_output_jsonl) {
                     LOG(