// return res;
//}
- const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
+ const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
- const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
+ const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
GGML_ASSERT(nqptg <= 32);
#undef FATTN_SMEM
} else {
// half4x4 kernel
- const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
+ const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
- const int nkpsg = 1*ncpsg;
+ const int nhptg = 1; // heads per threadgroup
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
ggml_metal_op_concurrency_reset(ctx);
}
+ // note: for simplicity, assume that K is larger than or equal to V
+ GGML_ASSERT(ne10 >= ne20);
+
// (GGML_PAD(ne00, 128) + 4*ncpsg)*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
// 2*GGML_PAD(ne20, 128)*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
-
- int64_t nsgmax = 2;
- while (true) {
- const size_t smem = FATTN_SMEM(nsgmax);
- // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
- if (smem > props_dev->max_theadgroup_memory_size/2) {
- break;
- }
- nsgmax *= 2;
- }
- nsgmax /= 2;
-
- // simdgroups per threadgroup (a.k.a. warps)
- //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
+#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
int64_t nsg = 1;
- while (nsg <= nsgt) {
- nsg *= 2;
- }
- nsg /= 2;
// workgroups
// each workgroup handles nsg*ncpsg cache values
} else {
nwg = 32;
nsg = 1;
- while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+ while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
nsg *= 2;
}
}
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
} else {
// sanity checks
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
- ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
+ ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
// sync the 2 kernels
ggml_metal_op_concurrency_reset(ctx);
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short DK, // K head size
short DV, // V head size
- short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
+ short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
constant ggml_metal_kargs_flash_attn_ext & args,
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
- short NE, // head elements per thread
- short Q, // queries per threadgroup
- short C, // cache items per threadgroup
- short NSG> // number of simd groups
-void kernel_flash_attn_ext_vec_impl(
+ short NE = 4, // head elements per thread
+ short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
+ short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
device const char * k,
static_assert(DV % 32 == 0, "DV must be divisible by 32");
#define NWG (FC_flash_attn_ext_vec_nwg)
+#define NSG (FC_flash_attn_ext_vec_nsg)
#define NS10 (FC_flash_attn_ext_vec_ns10)
#define NS20 (FC_flash_attn_ext_vec_ns20)
const short T = PK + NSG*SH; // shared memory size per query in (half)
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results
// store the result for all queries in shared memory (the O matrix from the paper)
so4 += tiisg;
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q);
- for (short i = tiisg; i < PK4; i += NW) {
- if (iq1 < args.ne01 && i < DK4) {
- sq4[i] = (q4_t) q4[i];
- } else {
- sq4[i] = (q4_t) 0.0f;
+ if (iq1 < args.ne01) {
+ for (short i = tiisg; i < PK4; i += NW) {
+ if (i < DK4) {
+ sq4[i] = (q4_t) q4[i];
+ } else {
+ sq4[i] = (q4_t) 0.0f;
+ }
}
}
}
// skip -INF blocks
- if (simd_max(sm[tiisg]) == -INFINITY) {
+ if (simd_max(sm[tiisg]) <= -MAXHALF) {
continue;
}
}
#undef NWG
+#undef NSG
#undef NS10
#undef NS20
}
-template<
- typename q4_t, // query types in shared memory
- typename k4_t, // key types in shared memory
- typename v4_t, // value types in shared memory
- typename qk_t, // Q*K types
- typename s_t, // soft-max types
- typename s4_t,
- typename o4_t, // attention accumulation types
- typename kd4_t, // key type in device memory
- short nl_k,
- void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
- typename vd4_t, // value type in device memory
- short nl_v,
- void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
- short DK, // K head size
- short DV, // V head size
- short NE = 4, // head elements per thread
- short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
- short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
- constant ggml_metal_kargs_flash_attn_ext_vec & args,
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device const char * sinks,
- device const char * pad,
- device char * dst,
- threadgroup half * shmem_f16 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
- switch (FC_flash_attn_ext_vec_nsg) {
- // note: disabled cases to reduce library load time
- case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
- case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
- case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
- //case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
- //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
- //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
- }
-#undef FWD_TMPL
-#undef FWD_ARGS
-}
-
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//