const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
- for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+ float4x4 reg_f;
+
+ for (int i = 0; i < 8; i++) {
+ reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+ reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
}
+
+ reg = (type4x4) reg_f;
}
template <typename type4x4>
const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
- for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+ float4x4 reg_f;
+
+ for (int i = 0; i < 8; i++) {
+ reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+ reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
}
+
+ reg = (type4x4) reg_f;
}
template <typename type4x4>
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
+ float4x4 reg_f;
+
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- reg[i/2][2*(i%2)+0] = d * x0 + md;
- reg[i/2][2*(i%2)+1] = d * x1 + md;
+ reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+ reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
}
+
+ reg = (type4x4) reg_f;
}
template <typename type4x4>
const int gh_mv = il ? 12 : 0;
const int gh_bk = il ? 0 : 4;
+ float4x4 reg_f;
+
for (int i = 0; i < 8; i++) {
// extract the 5-th bits for x0 and x1
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- reg[i/2][2*(i%2)+0] = d * x0 + m;
- reg[i/2][2*(i%2)+1] = d * x1 + m;
+ reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+ reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
}
+
+ reg = (type4x4) reg_f;
}
template <typename type4x4>
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
+ float4x4 reg_f;
+
for (int i = 0; i < 16; i++) {
- reg[i/4][i%4] = (qs[i + 16*il] * d);
+ reg_f[i/4][i%4] = (qs[i + 16*il] * d);
}
+
+ reg = (type4x4) reg_f;
}
template <typename type4x4>
}
// ref: https://arxiv.org/pdf/2307.08691.pdf
-// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+template<
+ typename q_t, // query types in shared memory
+ typename q4_t,
+ typename q8x8_t,
+ typename k_t, // key types in shared memory
+ typename k4x4_t,
+ typename k8x8_t,
+ typename v_t, // value types in shared memory
+ typename v4x4_t,
+ typename v8x8_t,
+ typename qk_t, // Q*K types
+ typename qk8x8_t,
+ typename s_t, // soft-max types
+ typename s8x8_t,
+ typename o_t, // attention accumulation types
+ typename o4_t,
+ typename o8x8_t,
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+ typename vd4x4_t, // key type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short D, // head size
+ short Q = 8, // queries per threadgroup
+ short KV = 8, // key/value processed per each simdgroup
+ short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb23,
- constant uint64_t & nb31,
- constant int64_t & ne1,
- constant int64_t & ne2,
+ constant int32_t & ne01,
+ constant int32_t & ne02,
+ constant int32_t & ne03,
+ constant uint32_t & nb01,
+ constant uint32_t & nb02,
+ constant uint32_t & nb03,
+ constant int32_t & ne11,
+ constant int32_t & ne_12_2, // assume K and V are same shape
+ constant int32_t & ne_12_3,
+ constant uint32_t & nb_12_1,
+ constant uint32_t & nb_12_2,
+ constant uint32_t & nb_12_3,
+ constant uint32_t & nb31,
+ constant int32_t & ne1,
+ constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
- constant uint32_t & n_head_log2,
+ constant uint16_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ ushort3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2];
const short D8 = D/8;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
- const short SH = (C + Q); // shared memory per simdgroup in (half)
+ const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
- const short T = D + 2*nsg*SH; // shared memory size per query in (half)
- const short TF = T/2; // shared memory size per query in (float)
- const short T4 = T/4; // shared memory size per query in (half4)
+ const short TS = nsg*SH; // shared memory size per query in (s_t == float)
+ const short T = D + 2*TS; // shared memory size per query in (half)
- threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
- threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+ threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
- threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
- threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
+ threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+ threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- simdgroup_half8x8 lo[D8];
+ o8x8_t lo[D8];
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
for (short i = tiisg; i < D4; i += NW) {
if (iq1 + j < ne01) {
- sq4[j*T4 + i] = (half4) q4[i];
+ sq4[j*D4 + i] = (q4_t) q4[i];
} else {
- sq4[j*T4 + i] = 0.0h;
+ sq4[j*D4 + i] = (q4_t) 0.0f;
}
}
}
// zero out lo
for (short i = 0; i < D8; ++i) {
- lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+ lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
}
// zero out shared memory SH
for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH; i += NW) {
- ss[j*TF + i] = 0.0f;
+ ss[j*TS + i] = 0.0f;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- float S[Q] = { [0 ... Q-1] = 0.0f };
- float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
+ half S[Q] = { [0 ... Q-1] = 0.0f };
+ half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
// thread indices inside the simdgroup
+ // TODO: see if we can utilize quad-group functions for better performance
+ // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
const short tx = tiisg%4;
const short ty = tiisg/4;
- // assume K and V are same shape
- const short ne22 = ne12;
- const short ne23 = ne13;
-
- // broadcast k
- const short rk2 = ne02/ne12;
- const short rk3 = ne03/ne13;
-
- const short ik2 = iq2/rk2;
- const short ik3 = iq3/rk3;
+ // broadcast kv
+ //const short rk2 = ne02/ne12;
+ //const short rk3 = ne03/ne13;
- // broadcast v
- const short rv2 = ne02/ne22;
- const short rv3 = ne03/ne23;
-
- const short iv2 = iq2/rv2;
- const short iv3 = iq3/rv3;
+ const short ikv2 = iq2/(ne02/ne_12_2);
+ const short ikv3 = iq3/(ne03/ne_12_3);
// load the queries from shared memory into local memory
- simdgroup_half8x8 mq[D8];
+ q8x8_t mq[D8];
for (short i = 0; i < D8; ++i) {
- simdgroup_load(mq[i], sq + i*8, T);
+ simdgroup_load(mq[i], sq + i*8, D);
}
- // pointer to the mask
- device const half * mp = (device const half *) (mask + iq1*nb31);
+ const bool has_mask = mask != q;
- float slope = 1.0f;
+ half slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
- const uint32_t h = iq2;
+ const short h = iq2;
- const float base = h < n_head_log2 ? m0 : m1;
- const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const half base = h < n_head_log2 ? m0 : m1;
+ const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exph);
}
break;
}
+ if (has_mask) {
+ // used to detect blocks full of -INF
+ half smax = -INFINITY;
+
+ // load the mask in shared memory
+ for (short j = 0; j < Q; ++j) {
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
+
+ const half m = pm[ic + tiisg];
+
+ ss[j*TS + C + tiisg] = m;
+ smax = max(smax, m);
+ }
+
+ smax = simd_max(smax);
+
+ if (smax == -INFINITY) {
+ continue;
+ }
+ }
+
// Q*K^T
{
for (short cc = 0; cc < C/8; ++cc) {
- simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
+ qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
// this is compile-time check, so it does not have runtime overhead
- if (is_same<block_q, half4x4>::value) {
+ if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
- device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+#pragma unroll
for (short i = 0; i < D8; ++i) {
- simdgroup_half8x8 mk;
- simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+ k8x8_t mk;
+ simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
- device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
- half4x4 tmp;
- dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
- skv4[4*ty + tx] = tmp;
+ {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll
for (short k = 0; k < 4; ++k) {
- simdgroup_half8x8 mk;
+ k8x8_t mk;
- simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
- simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
} else {
if (ii + tx < D16) {
- half4x4 tmp;
- dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
- skv4[4*ty + tx] = tmp;
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
- simdgroup_half8x8 mk;
+ k8x8_t mk;
- simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
- simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
}
}
}
}
- simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+ // cast qk_t -> s_t
+ //s8x8_t mqks(1.0f);
+ //simdgroup_multiply(mqks, mqk, mqks);
+ //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+ simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
}
}
- // used to detect blocks full of -INF
- float smax = -INFINITY;
-
// online softmax
{
- float ms[Q];
-
- for (short j = 0; j < Q; ++j) {
- const float m = M[j];
+ for (ushort j = 0; j < Q; ++j) {
+ const half m = M[j];
// scale and apply the logitcap / mask
- float s = ss[j*TF + tiisg]*scale;
+ half s = ss[j*TS + tiisg]*scale;
if (logit_softcap != 0.0f) {
s = logit_softcap*precise::tanh(s);
}
- if (mask != q) {
- // mqk = mqk + mask*slope
- s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
- }
+ // mqk = mqk + mask*slope
+ s += slope*ss[j*TS + C + tiisg];
- smax = simd_max(max(smax, s));
M[j] = simd_max(max(M[j], s));
- ms[j] = exp(m - M[j]);
- const float vs = exp(s - M[j]);
+ const half ms = exp(m - M[j]);
+ const half vs = exp(s - M[j]);
- S[j] = S[j]*ms[j] + simd_sum(vs);
+ S[j] = S[j]*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
- ss[j*TF + tiisg] = vs;
- }
+ ss[j*TS + tiisg] = vs;
- // create a QxQ diagonal matrix for rescaling the output
- if (tiisg < Q) {
- ss[tiisg*TF + C + tiisg] = ms[tiisg];
+ // create a QxQ diagonal matrix for rescaling the output
+ if (tiisg == j) {
+ ss[j*TS + 2*C + j] = ms;
+ }
}
}
- // skip -INF blocks
- if (smax == -INFINITY) {
- continue;
- }
-
// O = diag(ms)*O
{
- simdgroup_float8x8 mm;
- simdgroup_load(mm, ss + C, TF, 0, false);
+ s8x8_t mm;
+ simdgroup_load(mm, ss + 2*C, TS, 0, false);
+#pragma unroll
for (short i = 0; i < D8; ++i) {
simdgroup_multiply(lo[i], mm, lo[i]);
}
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/8; ++cc) {
- simdgroup_float8x8 ms;
- simdgroup_load(ms, ss + 8*cc, TF, 0, false);
+ s8x8_t ms;
+ simdgroup_load(ms, ss + 8*cc, TS, 0, false);
- if (is_same<block_q, half4x4>::value) {
+ if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
- device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll
for (short i = 0; i < D8; ++i) {
- simdgroup_half8x8 mv;
- simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
+ v8x8_t mv;
+ simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
- device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
if (D16%4 == 0) {
// no need for bound checks
- half4x4 tmp;
- dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
- skv4[4*ty + tx] = tmp;
+ {
+ v4x4_t tmp;
+ deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+ sv4x4[4*ty + tx] = tmp;
+ }
simdgroup_barrier(mem_flags::mem_threadgroup);
#pragma unroll
for (short k = 0; k < 4; ++k) {
- simdgroup_half8x8 mv;
+ v8x8_t mv;
- simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
- simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + tx < D16) {
- half4x4 tmp;
- dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
- skv4[4*ty + tx] = tmp;
+ v4x4_t tmp;
+ deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+ sv4x4[4*ty + tx] = tmp;
}
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < D16; ++k) {
- simdgroup_half8x8 mv;
+ v8x8_t mv;
- simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
- simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = 0; j < Q; ++j) {
if (tiisg == 0) {
- ss[j*TF + 0] = S[j];
- ss[j*TF + 1] = M[j];
+ ss[j*TS + 0] = S[j];
+ ss[j*TS + 1] = M[j];
}
}
}
// reduce the warps sequentially
- for (short sg = 1; sg < nsg; ++sg) {
- float S = { 0.0f };
- float M = { -FLT_MAX/2 };
+ for (ushort sg = 1; sg < nsg; ++sg) {
+ half S = { 0.0f };
+ half M = { -__FLT16_MAX__/2 };
threadgroup_barrier(mem_flags::mem_threadgroup);
// each simdgroup stores its output to shared memory, reusing sq
if (sgitg == sg) {
for (short i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ simdgroup_store(lo[i], so + i*8, D, 0, false);
}
}
// the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) {
for (short j = 0; j < Q; ++j) {
- const float S0 = ss[j*TF + 0];
- const float S1 = ss[j*TF + sg*SH + 0];
+ const half S0 = ss[j*TS + 0];
+ const half S1 = ss[j*TS + sg*SH + 0];
- const float M0 = ss[j*TF + 1];
- const float M1 = ss[j*TF + sg*SH + 1];
+ const half M0 = ss[j*TS + 1];
+ const half M1 = ss[j*TS + sg*SH + 1];
M = max(M0, M1);
- const float ms0 = exp(M0 - M);
- const float ms1 = exp(M1 - M);
+ const half ms0 = exp(M0 - M);
+ const half ms1 = exp(M1 - M);
S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
- ss[j*TF + 0] = S;
- ss[j*TF + 1] = M;
+ ss[j*TS + 0] = S;
+ ss[j*TS + 1] = M;
- ss[j*TF + C + j ] = ms0;
- ss[j*TF + C + j + sg*SH] = ms1;
+ ss[j*TS + 2*C + j ] = ms0;
+ ss[j*TS + 2*C + j + sg*SH] = ms1;
}
}
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
- simdgroup_half8x8 t;
- simdgroup_float8x8 ms0;
- simdgroup_float8x8 ms1;
+ s8x8_t ms0;
+ s8x8_t ms1;
- simdgroup_load(ms0, ss + C, TF, 0, false);
- simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+ simdgroup_load(ms0, ss + 2*C, TS, 0, false);
+ simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
for (short i = 0; i < D8; ++i) {
- simdgroup_load (t, sq + i*8, T, 0, false);
+ o8x8_t t;
+
+ simdgroup_load (t, so + i*8, D, 0, false);
simdgroup_multiply(t, ms1, t);
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
// store result to shared memory (reuse sq)
if (sgitg == 0) {
for (short i = 0; i < D8; ++i) {
- simdgroup_store(lo[i], sq + i*8, T, 0, false);
+ simdgroup_store(lo[i], so + i*8, D, 0, false);
}
}
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
- const float S = ss[j*TF + 0];
+ const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
- dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+ dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
}
}
}
}
-typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
-
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
-
-// NOTE: can use half instead of float precision for some extra perf
-// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
+// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
+// template to be able to explore different combinations
+//
+#define FA_TYPES \
+ half, half4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ half, half4x4, simdgroup_half8x8, \
+ float, simdgroup_float8x8, \
+ float, simdgroup_float8x8, \
+ half, half4, simdgroup_half8x8
+
+typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
+
+#if !defined(GGML_METAL_NO_BFLOAT)
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
+#endif
+
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
+
+#undef FA_TYPES
+
+template<
+ typename q4_t, // query types in shared memory
+ typename q4x4_t,
+ typename k4x4_t, // key types in shared memory
+ typename v4x4_t, // value types in shared memory
+ typename qk_t, // Q*K types
+ typename s_t, // soft-max types
+ typename s4_t,
+ typename s4x4_t,
+ typename o4x4_t, // attention accumulation types
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+ typename vd4x4_t, // key type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short D, // head size
+ short Q = 1, // queries per threadgroup
+ short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant uint64_t & nb21,
- constant uint64_t & nb22,
- constant uint64_t & nb23,
- constant uint64_t & nb31,
- constant int64_t & ne1,
- constant int64_t & ne2,
+ constant int32_t & ne01,
+ constant int32_t & ne02,
+ constant int32_t & ne03,
+ constant uint32_t & nb01,
+ constant uint32_t & nb02,
+ constant uint32_t & nb03,
+ constant int32_t & ne11,
+ constant int32_t & ne_12_2, // assume K and V are same shape
+ constant int32_t & ne_12_3,
+ constant uint32_t & nb_12_1,
+ constant uint32_t & nb_12_2,
+ constant uint32_t & nb_12_3,
+ constant uint32_t & nb31,
+ constant int32_t & ne1,
+ constant int32_t & ne2,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
- constant uint32_t & n_head_log2,
+ constant uint16_t & n_head_log2,
constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ ushort3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2];
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NW4 = NW/4;
- const short SH = C; // shared memory per simdgroup in (half)
+ const short SH = 2*C; // shared memory per simdgroup
- const short T = D + 2*nsg*SH; // shared memory size per query in (half)
+ const short T = D + nsg*SH; // shared memory size per query in (half)
- //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
- threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
- threadgroup half4x4 * sq44 = (threadgroup half4x4 *) (shared + 0*D); // same as above but in half4x4
- threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
- threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
- threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
+ threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
+ threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
+ threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- float4x4 lo[D16/NW4];
+ o4x4_t lo[D16/NW4];
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
for (short i = tiisg; i < D4; i += NW) {
if (iq1 < ne01) {
- sq4[i] = (half4) q4[i];
+ sq4[i] = (q4_t) q4[i];
} else {
- sq4[i] = 0.0h;
+ sq4[i] = (q4_t) 0.0f;
}
}
// zero out lo
for (short i = 0; i < D16/NW4; i += NW4) {
- lo[i] = float4x4(0.0f);
+ lo[i] = (o4x4_t) 0.0f;
}
// zero out shared memory SH
for (short i = tiisg; i < SH/4; i += NW) {
- ss4[i] = 0.0h;
+ ss4[i] = (s4_t) 0.0f;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- float S = 0.0f;
- float M = -FLT_MAX/2;
+ half S = 0.0f;
+ half M = -__FLT16_MAX__/2;
// thread indices inside the simdgroup
const short tx = tiisg%8;
const short ty = tiisg/8;
- // assume K and V are same shape
- const short ne22 = ne12;
- const short ne23 = ne13;
-
- // broadcast k
- const short rk2 = ne02/ne12;
- const short rk3 = ne03/ne13;
-
- const short ik2 = iq2/rk2;
- const short ik3 = iq3/rk3;
+ // broadcast kv
+ //const short rk2 = ne02/ne12;
+ //const short rk3 = ne03/ne13;
- // broadcast v
- const short rv2 = ne02/ne22;
- const short rv3 = ne03/ne23;
-
- const short iv2 = iq2/rv2;
- const short iv3 = iq3/rv3;
+ const short ikv2 = iq2/(ne02/ne_12_2);
+ const short ikv3 = iq3/(ne03/ne_12_3);
// load the queries from shared memory into local memory
- float4x4 mq[D16/NW4];
+ q4x4_t mq[D16/NW4];
for (short ii = 0; ii < D16; ii += NW4) {
- mq[ii/NW4] = (float4x4) sq44[ii + tx];
+ mq[ii/NW4] = sq4x4[ii + tx];
}
+ const bool has_mask = mask != q;
+
// pointer to the mask
- device const half * mp = (device const half *) (mask + iq1*nb31);
+ device const half * pm = (device const half *) (mask + iq1*nb31);
- float slope = 1.0f;
+ half slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
- const uint32_t h = iq2;
+ const short h = iq2;
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const half base = h < n_head_log2 ? m0 : m1;
+ const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
- slope = pow(base, exp);
+ slope = pow(base, exph);
}
// loop over the KV cache
break;
}
+ if (has_mask) {
+ sm[tiisg] = pm[ic + tiisg];
+ }
+
// Q*K^T
{
// each simdgroup processes 1 query and 4 keys
for (short cc = 0; cc < C/4; ++cc) {
- float mqk = 0.0;
+ qk_t mqk = 0.0;
- device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+ device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tx;
- float4x4 mk;
- dequantize_func(pk + i/nl, i%nl, mk);
+ k4x4_t mk;
+ deq_k(pk + i/nl_k, i%nl_k, mk);
mqk +=
dot(mq[ii/NW4][0], mk[0]) +
mqk = logit_softcap*precise::tanh(mqk);
}
- mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
+ mqk += sm[4*cc + ty]*slope;
ss[4*cc + ty] = mqk;
}
// online softmax
{
- const short p = tiisg;
-
- const float m = M;
- const float s = ss[p];
+ const half m = M;
+ const half s = ss[tiisg];
M = simd_max(max(M, s));
- const float ms = exp(m - M);
- const float vs = exp(s - M);
+ const half ms = exp(m - M);
+ const half vs = exp(s - M);
S = S*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
- ss[p] = vs;
+ ss[tiisg] = vs;
// O = diag(ms)*O
#pragma unroll
{
#pragma unroll
for (short cc = 0; cc < C/4; ++cc) {
- device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+ device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
- const float4x4 lss(ss[4*cc + ty]);
+ const s4x4_t ms(ss[4*cc + ty]);
#pragma unroll
for (short ii = 0; ii < D16; ii += NW4) {
const short i = ii + tx;
- float4x4 mv;
- dequantize_func(pv4 + i/nl, i%nl, mv);
+ v4x4_t mv;
+ deq_v(pv4 + i/nl_v, i%nl_v, mv);
- lo[ii/NW4] += mv*lss;
+ lo[ii/NW4] += mv*ms;
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
- ss[0] = S;
- ss[1] = M;
+ ss[0] = (s_t) S;
+ ss[1] = (s_t) M;
}
}
// store results to shared memory
for (short i = tiisg; i < D16; i += NW4) {
- sr44[i] = lo[i/NW4];
+ sr4x4[i] = lo[i/NW4];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
- const float S0 = ss[ 0];
- const float S1 = ss[r*SH + 0];
+ const half S0 = ss[ 0];
+ const half S1 = ss[r*SH + 0];
- const float M0 = ss[ 1];
- const float M1 = ss[r*SH + 1];
+ const half M0 = ss[ 1];
+ const half M1 = ss[r*SH + 1];
- const float M = max(M0, M1);
+ const half M = max(M0, M1);
- const float ms0 = exp(M0 - M);
- const float ms1 = exp(M1 - M);
+ const half ms0 = exp(M0 - M);
+ const half ms1 = exp(M1 - M);
- const float S = S0*ms0 + S1*ms1;
+ const half S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
ss[0] = S;
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short i = tiisg; i < D16; i += NW) {
- sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
+ sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
}
}
const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) {
- dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
+ dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
}
}
}
-typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+//
+#define FA_TYPES \
+ half4, half4x4, \
+ half4x4, \
+ half4x4, \
+ float, \
+ half, half4, half4x4, \
+ half4x4
+
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
+#if !defined(GGML_METAL_NO_BFLOAT)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
+#if !defined(GGML_METAL_NO_BFLOAT)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
+#undef FA_TYPES
template<typename T0, typename T1>
kernel void kernel_cpy(