const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
- const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
- const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
- const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
__syncthreads();
}
+
+ // Attention sink: adjust running max and sum once per head
+ if (sinksf && blockIdx.y == 0) {
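+ // blockIdx.y indexes the parallel blocks that split the KV sequence;
+ // only block 0 applies the sink so it is counted exactly once per row.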
+ const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
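+ // Online softmax update: if the sink raises the row maximum, the
+ // previously accumulated sum and accumulator are scaled by exp(old_max - new_max).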
+ const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+ kqmax[j0/nwarps] = kqmax_new_j;
+
+ const half val = hexp(sink - kqmax[j0/nwarps]);
+ kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
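+ // Add the sink's exp term in a single lane so it is counted once
+ // after the later warp reduction of kqsum.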
+ if (threadIdx.x == 0) {
+ kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+ }
+
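+ // The sink has no associated V row, so the accumulator only needs the max rescale.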
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+ }
+ }
+ }
+
float2 * dst2 = (float2 *) dst;
#pragma unroll
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
- const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
- const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
- const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+ const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
__syncthreads();
}
+
+ // Attention sink: adjust running max and sum once per head
+ if (sinksf && blockIdx.y == 0) {
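+ // Same online-softmax sink update as the f16 kernel, with float accumulators.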
+ const float sink = sinksf[head];
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+ kqmax[j0/nwarps] = kqmax_new_j;
+
+ const float val = expf(sink - kqmax[j0/nwarps]);
+ kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+ if (threadIdx.x == 0) {
+ kqsum[j0/nwarps] += val;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+ VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+ }
+ }
+ }
+
float2 * dst2 = (float2 *) dst;
#pragma unroll
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
- const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
- const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
- const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
- const half2 * mask2 = (const half2 *) maskh;
+ const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+ const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+ const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const half2 * mask2 = (const half2 *) maskh;
+ const float * sinksf = (const float *) sinks;
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
__syncthreads();
}
+ // Apply attention sinks
+ if (sinksf && blockIdx.y == 0) {
+ const float sinkf = sinksf[head];
+ const half sinkh = __float2half(sinkf);
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (std::is_same<KQ_acc_t, float>::value) {
+ float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+ const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+ KQ_max_f[j0/nwarps] = kqmax_new;
+
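+ // Rescale the row sum and fold in the sink's exp term in a single expression.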
+ KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+ const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
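+ // Guard for head sizes where D/2 is not a multiple of the warp size.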
+ if (i0 + warp_size > D/2 && i >= D/2) {
+ break;
+ }
+ VKQ2[j*(D_padded/2) + i] *= scale_h2;
+ }
+ } else {
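+ // FP16 path: the row maximum is replicated across both halves of KQ_max_h2;
+ // the sink's exp term goes into one half of the rowsum, whose halves are summed at the end.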
+ half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+ half kqmax_new = fmaxf(kqmax_old, sinkh);
+ KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+ const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+ const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
+
+ KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+ const half val = hexp(sinkh - kqmax_new);
+ KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D/2 && i >= D/2) {
+ break;
+ }
+ VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+ }
+ }
+ }
+
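+ // VKQ2 aliases shared memory, so all warps must finish rescaling
+ // before the final combine below.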
+ __syncthreads();
+ }
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
- const ggml_tensor * sinks = dst->src[4];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
- // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
- if (sinks && !fp16_mma_available(cc)) {
- if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
- ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
- } else {
- ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
- }
- return;
- }
-
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);