GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
+ GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- const int ne11_mm_min = 4;
+ const int ne11_mm_min = 8;
// first try to use small-batch mat-mv kernels
// these should be efficient for BS [2, ~8]
- if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+ if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
(
(
(
- src0t == GGML_TYPE_F16 || // TODO: helper function
+ src0t == GGML_TYPE_F32 || // TODO: helper function
+ src0t == GGML_TYPE_F16 ||
src0t == GGML_TYPE_Q4_0 ||
src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 ||
// values and there can be some tail effects when nsg is high. need to confirm this
//
const int nsg = 2; // num simdgroups per threadgroup
- const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+
+ // num threads along row per simdgroup
+ int nxpsg = 0;
+ if (ne00 % 256 == 0 && ne11 < 3) {
+ nxpsg = 16;
+ } else if (ne00 % 128 == 0) {
+ nxpsg = 8;
+ } else {
+ nxpsg = 4;
+ }
+
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
int r1ptg = 4; // num src1 rows per threadgroup
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
+ case GGML_TYPE_F32:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
- case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
/*.scale =*/ scale,
/*.max_bias =*/ max_bias,
/*.m0 =*/ m0,
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
+ if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
const size_t smem = FATTN_SMEM(nsg);
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
-#undef FATTN_SMEM
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+#undef FATTN_SMEM
} else {
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
//
- // ne00*(nsg)
+ // ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
+//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength) {
+ // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
+ if (smem > device.maxThreadgroupMemoryLength/2) {
break;
}
nsgmax *= 2;
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
}
nsg /= 2;
- const size_t smem = FATTN_SMEM(nsg);
+ // workgroups
+ // each workgroup handles nsg*nkpsg cache values
+ uint16_t nwg = 1;
+ if (4*nsg*nkpsg >= ne11) {
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ // using 1 workgroup -> write the result directly into dst
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ nwg = 32;
+ nsg = MIN(4, nsg);
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ // sanity checks
+ GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
+ GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
+
+ const int32_t nrows = ne1*ne2*ne3;
+
+ // temp buffer for writing the results from each workgroup
+ // - ne20: the size of the head vector
+ // - + 2: the S and M values for each intermediate result
+ const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
+ id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
+ if (!h_tmp) {
+ GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
+ return 0;
+ }
+
+ //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
+ //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
+
+ [encoder setBuffer:h_tmp offset:0 atIndex:6];
+ [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
+ // reduce the results from the workgroups
+ {
+ ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+ nrows,
+ ne20,
+ };
+
+ id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+
+ [encoder setComputePipelineState:pipeline0];
+ [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
+ [encoder setBuffer:h_tmp offset:0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+
+ //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+ }
+ }
#undef FATTN_SMEM
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
}
} break;
case GGML_OP_DUP:
reg = (type4x4)(*src);
}
+template <typename type4>
+void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
+ reg = (type4)(*src);
+}
+
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
reg = (type4x4)(*src);
#pragma unroll(r1ptg)
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
-
}
}
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
+template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
+
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
device const char * mask,
device const char * sinks,
device char * dst,
+ constant uint16_t & nwg,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
+ const short iwg = tgpig[2]%nwg;
- const int iq3 = tgpig[2];
+ const int iq3 = tgpig[2]/nwg;
const int iq2 = tgpig[1];
const int iq1 = tgpig[0];
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
+ for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
const int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
}
}
- if (sinks != q && sgitg == 0) {
+ if (sinks != q && sgitg == 0 && iwg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- device float4 * dst4 = (device float4 *) dst;
-
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
- const float S = ss[0];
+ const int64_t nrows = args.ne3*args.ne2*args.ne1;
+ const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
+
+ device float4 * dst4 = (device float4 *) dst;
+ device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
+
+ const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
+ // interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
- dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
+ dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+ }
+
+ // store S and M
+ if (nwg > 1 && tiisg == 0) {
+ dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
+ dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
}
}
}
#undef FA_TYPES
+kernel void kernel_flash_attn_ext_reduce(
+ constant ggml_metal_kargs_flash_attn_ext_reduce & args,
+ device const char * htmp,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const uint64_t rid = tgpig;
+
+ const short nwg = 32;
+ const short iwg = tiisg;
+ const short DV = args.ne20;
+ const short DV4 = DV/4;
+
+ device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
+ device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
+ device float4 * dst4 = (device float4 *) dst + rid*DV4;
+
+ float S = ss[rid*(2*nwg) + 2*iwg + 0];
+ float M = ss[rid*(2*nwg) + 2*iwg + 1];
+
+ const float m = simd_max(M);
+ const float ms = exp(M - m);
+
+ S = 1.0f/simd_sum(S*ms);
+
+ for (int i = sgitg; i < DV4; i += nwg) {
+ const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
+
+ if (iwg == 0) {
+ dst4[i] = v*S;
+ }
+ }
+}
+
template<typename T>
kernel void kernel_set(
constant ggml_metal_kargs_set & args,