]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: add attention sinks for tile and wmma (llama/15178)
authorAman Gupta <redacted>
Sat, 9 Aug 2025 12:00:24 +0000 (20:00 +0800)
committerGeorgi Gerganov <redacted>
Thu, 14 Aug 2025 11:17:28 +0000 (14:17 +0300)
* CUDA: add attention sinks for tile and wmma

* Review: formatting changes + remove syncthreads from tile + remove warp_reduce_max from wmma

src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/fattn.cu

index 0fcfaa32ea4645ef924cb2b961845fbd9e51cfec..1e23f8f79c202879c02735baa8d21767445749af 100644 (file)
@@ -49,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -242,6 +243,31 @@ static __global__ void flash_attn_tile_ext_f16(
         __syncthreads();
     }
 
+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps].x = __hadd(kqsum[j0/nwarps].x, val);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
index 23550cbbd9736a9410cd137d66b057d953928529..c58194937d7a633cb6687b7c604c2c6c4000ae06 100644 (file)
@@ -60,10 +60,11 @@ static __global__ void flash_attn_tile_ext_f32(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const float2 * Q_f2   = (const float2 *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask  + nb33*(sequence % ne33)                          + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
@@ -252,6 +253,33 @@ static __global__ void flash_attn_tile_ext_f32(
         __syncthreads();
     }
 
+
+    //Attention sink: adjust running max and sum once per head
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
+            kqmax[j0/nwarps] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j0/nwarps]);
+            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            if (threadIdx.x == 0) {
+                kqsum[j0/nwarps] += val;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
+                VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
+            }
+        }
+    }
+
     float2 * dst2 = (float2 *) dst;
 
 #pragma unroll
index 93d4d810218d78a6560b5b9f7a97663fb806d833..fdc4d17da2da962e64afd3708e9101afa8c38964 100644 (file)
@@ -82,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
-    const half  * V_h   = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
-    const half2 * mask2 = (const half2 *)  maskh;
+    const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0);
+    const half  * K_h    = (const half  *) (K    + nb13* sequence         + nb12*(head / gqa_ratio));
+    const half  * V_h    = (const half  *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0);
+    const half2 * mask2  = (const half2 *)  maskh;
+    const float * sinksf = (const float *) sinks;
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -381,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();
     }
 
+    // Apply attention sinks
+    if (sinksf && blockIdx.y == 0) {
+        const float sinkf = sinksf[head];
+        const half  sinkh = __float2half(sinkf);
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+                const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+                KQ_max_f[j0/nwarps] = kqmax_new;
+
+                KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+                const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= scale_h2;
+                }
+            } else {
+                half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+                half kqmax_new = fmaxf(kqmax_old, sinkh);
+                KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+                const half  KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+                const half2 KQ_max_scale   = __half2half2(KQ_max_scale_h);
+
+                KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+                const half val = hexp(sinkh - kqmax_new);
+                KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    const int i = i0 + threadIdx.x;
+                    if (i0 + warp_size > D/2 && i >= D/2) break;
+                    VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+                }
+            }
+        }
+
+        __syncthreads();
+    }
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j_VKQ = j0 + threadIdx.y;
index 6c1185deac8507b212cda3f8dce5a7c6339e85fd..22e90d0e7b31611da2c55e2849f91609e0ee49bc 100644 (file)
@@ -274,23 +274,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * K     = dst->src[1];
     const ggml_tensor * V     = dst->src[2];
     const ggml_tensor * mask  = dst->src[3];
-    const ggml_tensor * sinks = dst->src[4];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks && !fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
-        return;
-    }
-
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);