threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
// O = diag(ms)*O
{
- s8x8_t mm;
- simdgroup_load(mm, ss + 2*C, TS, 0, false);
+ s8x8_t ms;
+ simdgroup_load(ms, ss + 2*C, TS, 0, false);
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
- simdgroup_multiply(lo[i], mm, lo[i]);
+ simdgroup_multiply(lo[i], ms, lo[i]);
}
}
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/8; ++cc) {
- s8x8_t ms;
- simdgroup_load(ms, ss + 8*cc, TS, 0, false);
+ s8x8_t vs;
+ simdgroup_load(vs, ss + 8*cc, TS, 0, false);
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
v8x8_t mv;
simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
- simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
+ simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
}
} else {
for (short ii = 0; ii < DV16; ii += 4) {
v8x8_t mv;
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
}
} else {
if (ii + tx < DV16) {
v8x8_t mv;
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
+ simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
}
}
}
}
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- for (short j = 0; j < Q; ++j) {
- if (tiisg == 0) {
- ss[j*TS + 0] = S[j];
- ss[j*TS + 1] = M[j];
- }
+ for (short j = tiisg; j < Q; j += NW) {
+ ss[j*TS + 0] = S[j];
+ ss[j*TS + 1] = M[j];
}
}
- // reduce the warps sequentially
- for (ushort sg = 1; sg < nsg; ++sg) {
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // each simdgroup stores its output to shared memory, reusing sq
- if (sgitg == sg) {
- for (short i = 0; i < DV8; ++i) {
- simdgroup_store(lo[i], so + i*8, DV, 0, false);
- }
+ threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
+ threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
+
+ // store result to shared memory in F32
+ if (sgitg == 0) {
+ for (short i = 0; i < DV8; ++i) {
+ //simdgroup_store(lo[i], so + i*8, DV, 0, false);
+ simdgroup_float8x8 t(1.0f);
+ simdgroup_multiply(t, lo[i], t);
+ simdgroup_store(t, so + i*8, DV, 0, false);
}
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- // the first simdgroup accumulates the results from the other simdgroups
- if (sgitg == 0) {
- for (short j = 0; j < Q; ++j) {
- const float S0 = ss[j*TS + 0];
- const float S1 = ss[j*TS + sg*SH + 0];
+ // reduce the warps sequentially
+ for (ushort sg = 1; sg < nsg; ++sg) {
+ if (sgitg == sg) {
+ for (short j = tiisg; j < Q; j += NW) {
+ const float S0 = ss[j*TS - 1*SH + 0];
+ const float S1 = ss[j*TS + 0];
- const float M0 = ss[j*TS + 1];
- const float M1 = ss[j*TS + sg*SH + 1];
+ const float M0 = ss[j*TS - 1*SH + 1];
+ const float M1 = ss[j*TS + 1];
const float M = max(M0, M1);
- const float ms0 = exp(M0 - M);
- const float ms1 = exp(M1 - M);
+ float ms0 = exp(M0 - M);
+ float ms1 = exp(M1 - M);
const float S = S0*ms0 + S1*ms1;
- if (tiisg == 0) {
- ss[j*TS + 0] = S;
- ss[j*TS + 1] = M;
+ ss[j*TS + 0] = S;
+ ss[j*TS + 1] = M;
- ss[j*TS + 2*C + j ] = ms0;
- ss[j*TS + 2*C + j + sg*SH] = ms1;
- }
+ ss[j*TS + 2*C + j - 1*SH] = ms0;
+ ss[j*TS + 2*C + j ] = ms1;
}
+ //simdgroup_barrier(mem_flags::mem_threadgroup);
+
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
{
s8x8_t ms0;
s8x8_t ms1;
- simdgroup_load(ms0, ss + 2*C, TS, 0, false);
- simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
+ simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
+ simdgroup_load(ms1, ss + 2*C, TS, 0, false);
#pragma unroll(DV8)
for (short i = 0; i < DV8; ++i) {
- o8x8_t t;
+ simdgroup_float8x8 t;
simdgroup_load (t, so + i*8, DV, 0, false);
- simdgroup_multiply(t, ms1, t);
+ simdgroup_multiply(t, ms0, t);
- simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
+ simdgroup_multiply_accumulate(t, ms1, lo[i], t);
+ simdgroup_store(t, so + i*8, DV, 0, false);
}
}
}
- }
- // store result to shared memory (reuse sq)
- if (sgitg == 0) {
- for (short i = 0; i < DV8; ++i) {
- simdgroup_store(lo[i], so + i*8, DV, 0, false);
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
// final rescale with 1/S and store to global memory
for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
float, simdgroup_float8x8, \
- float, float4, simdgroup_float8x8
- //half, half4, simdgroup_half8x8
+ half, half4, simdgroup_half8x8
+ //float, float4, simdgroup_float8x8
#define FA_TYPES_BF \
bfloat, bfloat4, simdgroup_bfloat8x8, \
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
float, simdgroup_float8x8, \
float, simdgroup_float8x8, \
- float, float4, simdgroup_float8x8
- //half, half4, simdgroup_half8x8
+ half, half4, simdgroup_half8x8
+ //float, float4, simdgroup_float8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;