const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
+ const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
}
}
+ // If attention sinks are used, potentially re-scale if KQ_max is small.
+ // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum,
+ // so it is done unconditionally for every thread.
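+ // Per-column sketch of the update below, with m = KQ_max[col] and s = KQ_rowsum[col]:
+ //   m_new = max(m, sink); scale = exp(m - m_new);
+ //   s_new = scale*s + exp(sink - m_new); VKQ *= scale.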
+ if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+ float KQ_max_scale[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
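+ // jc is the column index handled by this thread; jc % ncols2 selects the head offset within the ncols2 heads packed into this tile.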
+ const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+ const float sink = sinks_f[jc % ncols2];
+
+ const float KQ_max_new = fmaxf(KQ_max[col], sink);
+ const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+ KQ_max_scale[col] = expf(KQ_max_diff);
+ KQ_max[col] = KQ_max_new;
+
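+ // Flush the scale to zero via its bit pattern if KQ_max_diff is below the flush-to-zero threshold.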
+ *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
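+ // The sink contributes exp(sink - KQ_max_new) to the softmax denominator (KQ_rowsum).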
+ const float KQ_max_add = expf(sink - KQ_max_new);
+ KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+ }
+
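+ // Re-scale the VKQ accumulators to account for the new KQ_max.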
+ if (ntiles == 1) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+ for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+ for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+ } else {
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+ for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+ for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+ VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+ }
+ }
+ }
+ }
+ }
+
// Combine VKQ accumulator values if np > 1.
// It's also faster to do small writes to shared memory followed by a single large write to VRAM than to do many small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
- const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+ const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+ const int head0 = zt * ncols2;
+
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+ const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+ (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+ (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
}
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
- const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+ const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+ const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+ const int head0 = zt * ncols2;
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+ const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
int kb0_stop_kernel = kb0_stop * kb_niter;
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+ (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);