From: Georgi Gerganov
Date: Tue, 26 Aug 2025 11:22:14 +0000 (+0300)
Subject: metal : optimize FA vec for large sequences and BS <= 8 (#15566)
X-Git-Tag: upstream/0.0.6527~241
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=b3964c1e890ef8c947afb36a5124ce6fcb2136d4;p=pkg%2Fggml%2Fsources%2Fllama.cpp

metal : optimize FA vec for large sequences and BS <= 8 (#15566)

* metal : optimize FA vec for large heads and sequences

* metal : adjust small-batch mul mv kernels

ggml-ci

* batched-bench : fix total speed computation

ggml-ci

* cont : add comments

ggml-ci
---

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 82c1ac1d..b9d36394 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -249,6 +249,7 @@ typedef struct {
     uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
+    int32_t  ne3;
     float    scale;
     float    max_bias;
     float    m0;
@@ -257,6 +258,11 @@ typedef struct {
     float    logit_softcap;
 } ggml_metal_kargs_flash_attn_ext;
 
+typedef struct {
+    int32_t nrows;
+    int32_t ne20;
+} ggml_metal_kargs_flash_attn_ext_reduce;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne02;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 7a05a982..1f93633d 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -291,6 +291,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -575,6 +579,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1324,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
@@ -1609,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
 
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                const int ne11_mm_min = 4;
+                const int ne11_mm_min = 8;
 
                 // first try to use small-batch mat-mv kernels
                 // these should be efficient for BS [2, ~8]
-                if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+                if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
                     (
                       (
                         (
-                          src0t == GGML_TYPE_F16 || // TODO: helper function
+                          src0t == GGML_TYPE_F32 || // TODO: helper function
+                          src0t == GGML_TYPE_F16 ||
                           src0t == GGML_TYPE_Q4_0 ||
                           src0t == GGML_TYPE_Q4_1 ||
                           src0t == GGML_TYPE_Q5_0 ||
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
                     // values and there can be some tail effects when nsg is high. need to confirm this
                     //
                     const int nsg = 2; // num simdgroups per threadgroup
-                    const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+                    // num threads along row per simdgroup
+                    int nxpsg = 0;
+                    if (ne00 % 256 == 0 && ne11 < 3) {
+                        nxpsg = 16;
+                    } else if (ne00 % 128 == 0) {
+                        nxpsg = 8;
+                    } else {
+                        nxpsg = 4;
+                    }
+
                     const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
                     const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
                     int       r1ptg = 4;         // num src1 rows per threadgroup
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
                     id<MTLComputePipelineState> pipeline = nil;
 
                     switch (src0->type) {
+                        case GGML_TYPE_F32:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
                         case GGML_TYPE_F16:
                             switch (r1ptg) {
                                 case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
                                 case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
                                 case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
                         case GGML_TYPE_Q5_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32  ].pipeline; break;
                         case GGML_TYPE_Q5_1:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32  ].pipeline; break;
                         case GGML_TYPE_Q8_0:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32  ].pipeline; break;
-                        case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
+                        case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32  ].pipeline; break;
                         case GGML_TYPE_Q2_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32  ].pipeline; break;
                         case GGML_TYPE_Q3_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32  ].pipeline; break;
                         case GGML_TYPE_Q4_K:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32  ].pipeline; break;
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
                 /*.nb33          =*/ nb33,
                 /*.ne1           =*/ ne1,
                 /*.ne2           =*/ ne2,
+                /*.ne3           =*/ ne3,
                 /*.scale         =*/ scale,
                 /*.max_bias      =*/ max_bias,
                 /*.m0            =*/ m0,
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
             } else {
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
             }
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
 
             if (!use_vec_kernel) {
                 // half8x8 kernel
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
 
                 while (true) {
                     const size_t smem = FATTN_SMEM(nsgmax);
-                    if (smem > device.maxThreadgroupMemoryLength) {
+                    if (smem > device.maxThreadgroupMemoryLength/2) {
                         break;
                     }
                     nsgmax *= 2;
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
 
                 const size_t smem = FATTN_SMEM(nsg);
 
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
                 //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
                 GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                 [encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
 
                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
             } else {
                 // half4x4 kernel
 
                 const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
                 const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+                const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
 
                 GGML_ASSERT(nqptg <= 32);
                 GGML_ASSERT(nqptg % 1 == 0);
@@ -5561,15 +5593,17 @@ static int ggml_metal_encode_node(
                 // for each query, we load it as f16 in shared memory (ne00)
                 // and store the soft_max values and the mask
                 //
-                // ne00*(nsg)
+                // ne20*(nsg)
                 // each simdgroup has a full f32 head vector in shared mem to accumulate results
                 //
 #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
 
                 int64_t nsgmax = 2;
                 while (true) {
                     const size_t smem = FATTN_SMEM(nsgmax);
-                    if (smem > device.maxThreadgroupMemoryLength) {
+                    // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
+                    if (smem > device.maxThreadgroupMemoryLength/2) {
                         break;
                     }
                     nsgmax *= 2;
@@ -5577,7 +5611,7 @@ static int ggml_metal_encode_node(
                 nsgmax /= 2;
 
                 // simdgroups per threadgroup (a.k.a. warps)
-                const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+                const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
 
                 int64_t nsg = 1;
                 while (nsg <= nsgt) {
@@ -5585,13 +5619,74 @@ static int ggml_metal_encode_node(
                 }
                 nsg /= 2;
 
-                const size_t smem = FATTN_SMEM(nsg);
+                // workgroups
+                // each workgroup handles nsg*nkpsg cache values
+                uint16_t nwg = 1;
+                if (4*nsg*nkpsg >= ne11) {
+                    const size_t smem = FATTN_SMEM(nsg);
 
-                //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
-                GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
-                [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                    //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+                    // using 1 workgroup -> write the result directly into dst
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+                    [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else {
+                    nwg = 32;
+                    nsg = MIN(4, nsg);
+
+                    const size_t smem = FATTN_SMEM(nsg);
+
+                    //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+                    // sanity checks
+                    GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+                    GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+                    const int32_t nrows = ne1*ne2*ne3;
+
+                    // temp buffer for writing the results from each workgroup
+                    // - ne20: the size of the head vector
+                    // - + 2:  the S and M values for each intermediate result
+                    const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
+                    id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
+                    if (!h_tmp) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
+                        return 0;
+                    }
+
+                    //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
+                    //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
+
+                    [encoder setBuffer:h_tmp offset:0 atIndex:6];
+                    [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+                    [encoder setThreadgroupMemoryLength:smem atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+                    // reduce the results from the workgroups
+                    {
+                        ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+                            nrows,
+                            ne20,
+                        };
+
+                        id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+
+                        [encoder setComputePipelineState:pipeline0];
+                        [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
+                        [encoder setBuffer:h_tmp offset:0 atIndex:1];
+                        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+                        //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
+                        [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+                    }
+                }
 #undef FATTN_SMEM
-
-                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
             }
         } break;
     case GGML_OP_DUP:
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 7037c1aa..fa80d6e4 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
+template <typename type4>
+void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*src);
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     reg = (type4x4)(*src);
@@ -3015,7 +3020,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
 #pragma unroll(r1ptg)
                 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                     sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
-
                 }
             }
 
@@ -3200,6 +3204,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
 typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
 typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)   mul_mv_ext_q4x4_f32_t;
 
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
+
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
@@ -4786,14 +4795,16 @@ kernel void kernel_flash_attn_ext_vec(
        device const char * mask,
        device const char * sinks,
        device       char * dst,
+       constant uint16_t & nwg,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3   ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
+    const short iwg = tgpig[2]%nwg;
 
-    const int iq3 = tgpig[2];
+    const int iq3 = tgpig[2]/nwg;
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0];
 
@@ -4872,7 +4883,7 @@ kernel void kernel_flash_attn_ext_vec(
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+        for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
             const int ic = ic0 + C*sgitg;
             if (ic >= args.ne11) {
                 break;
@@ -5002,7 +5013,7 @@ kernel void kernel_flash_attn_ext_vec(
             }
         }
 
-        if (sinks != q && sgitg == 0) {
+        if (sinks != q && sgitg == 0 && iwg == 0) {
             const float m = M;
             const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
@@ -5111,14 +5122,25 @@ kernel void kernel_flash_attn_ext_vec(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    device float4 * dst4 = (device float4 *) dst;
-
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
-        const float S = ss[0];
+        const int64_t nrows = args.ne3*args.ne2*args.ne1;
+        const int64_t rid   = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
+
+        device float4 * dst4 = (device float4 *) dst;
+        device float  * dst1 = (device float  *) dst + nrows*DV*nwg; // the S and M are stored after the results
+
+        const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
 
+        // interleave the workgroup data
         for (short i = tiisg; i < DV4; i += NW) {
-            dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
+            dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+        }
+
+        // store S and M
+        if (nwg > 1 && tiisg == 0) {
+            dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
+            dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
         }
     }
 }
@@ -5218,6 +5240,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
 
 #undef FA_TYPES
 
+kernel void kernel_flash_attn_ext_reduce(
+        constant ggml_metal_kargs_flash_attn_ext_reduce & args,
+        device const char * htmp,
+        device       char * dst,
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const uint64_t rid = tgpig;
+
+    const short nwg = 32;
+    const short iwg = tiisg;
+    const short DV  = args.ne20;
+    const short DV4 = DV/4;
+
+    device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
+    device const float  * ss    = (device const float  *) htmp + (uint64_t)args.nrows*DV*nwg;
+    device       float4 * dst4  = (device       float4 *) dst  + rid*DV4;
+
+    float S = ss[rid*(2*nwg) + 2*iwg + 0];
+    float M = ss[rid*(2*nwg) + 2*iwg + 1];
+
+    const float m  = simd_max(M);
+    const float ms = exp(M - m);
+
+    S = 1.0f/simd_sum(S*ms);
+
+    for (int i = sgitg; i < DV4; i += nwg) {
+        const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
+
+        if (iwg == 0) {
+            dst4[i] = v*S;
+        }
+    }
+}
+
 template <typename T>
 kernel void kernel_set(
         constant ggml_metal_kargs_set & args,
diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp
index 93efad32..23d03039 100644
--- a/tools/batched-bench/batched-bench.cpp
+++ b/tools/batched-bench/batched-bench.cpp
@@ -191,7 +191,7 @@ int main(int argc, char ** argv) {
 
             const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
             const float speed_tg = pl*tg / t_tg;
-            const float speed    = n_kv / t;
+            const float speed    = ((is_pp_shared ? pp : pl*pp) + pl*tg) / t;
 
             if(params.batched_bench_output_jsonl) {
                 LOG(
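
Note on the split-KV path (an editor's sketch, not part of the patch): when one pass per row cannot cover the KV sequence (the `4*nsg*nkpsg >= ne11` check above fails), the encoder dispatches nwg = 32 workgroups per output row. Each workgroup runs the streaming softmax over its stripe of the KV cache and writes an unnormalized head vector plus its softmax sum S and running maximum M into the temporary buffer; kernel_flash_attn_ext_reduce then merges the partials with a standard log-sum-exp combine. A minimal scalar C++ model of that merge for a single output row, with illustrative names (V, S, M are the per-workgroup partials; DV is the head size):

    // Scalar model of the merge performed by kernel_flash_attn_ext_reduce for
    // one output row. Illustrative code, not the kernel: V[w] is workgroup w's
    // unnormalized head-vector accumulator, S[w] its softmax denominator and
    // M[w] its running maximum, i.e. the (S, M) pair stored in the temp buffer.
    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> fa_reduce_row(const std::vector<std::vector<float>> & V, // [nwg][DV]
                                     const std::vector<float> & S,              // [nwg]
                                     const std::vector<float> & M) {            // [nwg]
        const size_t nwg = V.size();
        const size_t DV  = V[0].size();

        // global maximum over the partials (simd_max in the kernel)
        float m = -INFINITY;
        for (size_t w = 0; w < nwg; ++w) {
            m = std::max(m, M[w]);
        }

        // bring every partial denominator into the common base m (simd_sum of S*ms)
        float S_all = 0.0f;
        for (size_t w = 0; w < nwg; ++w) {
            S_all += S[w]*std::exp(M[w] - m);
        }

        // combine the unnormalized accumulators and apply the final 1/S once
        std::vector<float> out(DV, 0.0f);
        for (size_t w = 0; w < nwg; ++w) {
            const float ms = std::exp(M[w] - m);
            for (size_t i = 0; i < DV; ++i) {
                out[i] += V[w][i]*ms;
            }
        }
        for (size_t i = 0; i < DV; ++i) {
            out[i] /= S_all;
        }
        return out;
    }

Rescaling every partial into the common base m before summing is what keeps the combine numerically stable and order-independent, so the 32 stripes of the KV cache can be processed concurrently and merged afterwards.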
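On the temporary buffer layout: the first pass interleaves the nwg partial vectors of each row (lane w of float4 element i lands at offset i*nwg + w) so that the 32 lanes of the reduce simdgroup read consecutive float4 values, and packs all (S, M) pairs after the vector data. A hedged indexing sketch; the helper names are hypothetical, but the offsets mirror the dst4/dst1 addressing in kernel_flash_attn_ext_vec:

    // Index helpers for the nwg > 1 temp buffer (element units, names invented).
    #include <cstddef>

    // float4 index of partial-vector element i written by workgroup iwg of row rid;
    // matches dst4[rid*DV4*nwg + nwg*i + iwg] in the vec kernel.
    size_t idx_v(size_t rid, size_t i, size_t iwg, size_t DV4, size_t nwg) {
        return rid*DV4*nwg + i*nwg + iwg;
    }

    // float index of the S value for (row rid, workgroup iwg); M follows at +1.
    // All (S, M) pairs sit after the nrows*DV*nwg floats of vector data.
    size_t idx_sm(size_t nrows, size_t DV, size_t nwg, size_t rid, size_t iwg) {
        return nrows*DV*nwg + rid*(2*nwg) + 2*iwg;
    }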
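On the batched-bench change: the total speed was previously computed as n_kv / t, i.e. from the final KV-cache size, which only coincides with token throughput when every processed token occupies its own fresh cache cell. The fix counts the tokens actually pushed through the model: the prompt once when it is shared across the pl parallel sequences, otherwise once per sequence, plus all generated tokens, divided by the total elapsed time. A hedged restatement in C++ (pp, tg, pl, is_pp_shared, t as in batched-bench; the wrapper function is illustrative, the tool computes this inline):

    // Corrected total-speed formula from the diff above, in tokens per second.
    static float total_speed(int pp, int tg, int pl, bool is_pp_shared, float t) {
        const int n_tokens = (is_pp_shared ? pp : pl*pp) + pl*tg;
        return n_tokens / t;
    }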