threadgroup_barrier(mem_flags::mem_threadgroup);
{
- half S[Q] = { [0 ... Q-1] = 0.0f };
- half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+ float S[Q] = { [0 ... Q-1] = 0.0f };
+ float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
// thread indices inside the simdgroup
// TODO: see if we can utilize quad-group functions for better performance
const bool has_mask = mask != q;
- half slope = 1.0f;
+ float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
const short h = iq2;
- const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
if (has_mask) {
// used to detect blocks full of -INF
- half smax = -INFINITY;
+ float smax = -INFINITY;
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
- const half m = pm[ic + tiisg];
+ const float m = pm[ic + tiisg];
ss[j*TS + C + tiisg] = m;
smax = max(smax, m);
// online softmax
{
for (ushort j = 0; j < Q; ++j) {
- const half m = M[j];
+ const float m = M[j];
// scale and apply the logitcap / mask
- half s = ss[j*TS + tiisg]*args.scale;
+ float s = ss[j*TS + tiisg]*args.scale;
if (args.logit_softcap != 0.0f) {
s = args.logit_softcap*precise::tanh(s);
M[j] = simd_max(max(M[j], s));
- const half ms = exp(m - M[j]);
- const half vs = exp(s - M[j]);
+ const float ms = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
S[j] = S[j]*ms + simd_sum(vs);
// reduce the warps sequentially
for (ushort sg = 1; sg < nsg; ++sg) {
- half S = { 0.0f };
- half M = { -__FLT16_MAX__/2 };
+ float S = { 0.0f };
+ float M = { -__FLT16_MAX__/2 };
threadgroup_barrier(mem_flags::mem_threadgroup);
// the first simdgroup accumulates the results from the other simdgroups
if (sgitg == 0) {
for (short j = 0; j < Q; ++j) {
- const half S0 = ss[j*TS + 0];
- const half S1 = ss[j*TS + sg*SH + 0];
+ const float S0 = ss[j*TS + 0];
+ const float S1 = ss[j*TS + sg*SH + 0];
- const half M0 = ss[j*TS + 1];
- const half M1 = ss[j*TS + sg*SH + 1];
+ const float M0 = ss[j*TS + 1];
+ const float M1 = ss[j*TS + sg*SH + 1];
M = max(M0, M1);
- const half ms0 = exp(M0 - M);
- const half ms1 = exp(M1 - M);
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
S = S0*ms0 + S1*ms1;
constexpr short DV4 = DV/4;
constexpr short NW = N_SIMDWIDTH;
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
- constexpr short SH = 2*C; // shared memory per simdgroup
+ constexpr short SH = 4*C; // shared memory per simdgroup
const short T = DK + nsg*SH; // shared memory size per query in (half)
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
- threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
// store the result for all queries in local memory (the O matrix from the paper)
o4_t lo[DV4/NL];
threadgroup_barrier(mem_flags::mem_threadgroup);
{
- half S = 0.0f;
- half M = -__FLT16_MAX__/2;
+ float S = 0.0f;
+ float M = -__FLT16_MAX__/2;
// thread indices inside the simdgroup
const short tx = tiisg%NL;
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31);
- half slope = 1.0f;
+ float slope = 1.0f;
// ALiBi
if (args.max_bias > 0.0f) {
const short h = iq2;
- const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const float base = h < args.n_head_log2 ? args.m0 : args.m1;
const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
// online softmax
{
- const half m = M;
- const half s = ss[tiisg];
+ const float m = M;
+ const float s = ss[tiisg];
M = simd_max(max(M, s));
- const half ms = exp(m - M);
- const half vs = exp(s - M);
+ const float ms = exp(m - M);
+ const float vs = exp(s - M);
S = S*ms + simd_sum(vs);
v4_t mv;
deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
- lo[ii/NL] += mv*ms;
+ lo[ii/NL] += o4_t(float4(mv)*float4(ms));
}
}
}
// parallel reduce
for (short r = nsg/2; r > 0; r >>= 1) {
if (sgitg < r) {
- const half S0 = ss[ 0];
- const half S1 = ss[r*SH + 0];
+ const float S0 = ss[ 0];
+ const float S1 = ss[r*(SH/2) + 0];
- const half M0 = ss[ 1];
- const half M1 = ss[r*SH + 1];
+ const float M0 = ss[ 1];
+ const float M1 = ss[r*(SH/2) + 1];
- const half M = max(M0, M1);
+ const float M = max(M0, M1);
- const half ms0 = exp(M0 - M);
- const half ms1 = exp(M1 - M);
+ const float ms0 = exp(M0 - M);
+ const float ms1 = exp(M1 - M);
- const half S = S0*ms0 + S1*ms1;
+ const float S = S0*ms0 + S1*ms1;
if (tiisg == 0) {
ss[0] = S;
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//
#define FA_TYPES \
- half4, \
- half4, \
- half4, \
- float, \
- half, half4, \
+ half4, \
+ half4, \
+ half4, \
+ float, \
+ float, float4, \
half4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;