]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: attention sinks for mma FlashAttention (llama/15157)
authorJohannes Gäßler <redacted>
Fri, 8 Aug 2025 06:19:58 +0000 (08:19 +0200)
committerGeorgi Gerganov <redacted>
Thu, 14 Aug 2025 11:17:28 +0000 (14:17 +0300)
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn.cu
src/ggml-cuda/ggml-cuda.cu

index 3712538441719b0158bbe74cb17c420a4e1d45c2..39731baaeb7f42fb105b7bc0f9a37803badc4a92 100644 (file)
@@ -785,6 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const half2  * const __restrict__ K_h2,
         const half2  * const __restrict__ V_h2,
         const half2  * const __restrict__ mask_h2,
+        const float  * const __restrict__ sinks_f,
         float2       * const __restrict__ dstk,
         float2       * const __restrict__ dstk_fixup,
         const float scale,
@@ -957,6 +958,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
+    //     so it's being done unconditionally for every thread.
+    if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
+            const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
+            const float sink = sinks_f[jc % ncols2];
+
+            const float KQ_max_new = fmaxf(KQ_max[col], sink);
+            const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+            KQ_max_scale[col] = expf(KQ_max_diff);
+            KQ_max[col] = KQ_max_new;
+
+            *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+            const float KQ_max_add = expf(sink - KQ_max_new);
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
     // Combine VKQ accumulator values if np > 1.
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
@@ -1271,18 +1318,21 @@ static __global__ void flash_attn_ext_f16(
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+        const int head0 = zt * ncols2;
+
+        const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+        const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
         const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
             (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+        float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
         int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1295,12 +1345,12 @@ static __global__ void flash_attn_ext_f16(
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
@@ -1316,18 +1366,21 @@ static __global__ void flash_attn_ext_f16(
     }
 
     const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+
+    const int head0 = zt * ncols2;
 
-    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
-    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
+    const float2 * Q_f2    = (const float2 *) (Q + nb03*sequence + nb02* head0);
+    const half2  * K_h2    = (const half2  *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
     const half2  * mask_h2 = ncols2 == 1 && !mask ? nullptr :
         (const half2  *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+    float2       * dstk    = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
-    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
     int       kb0_stop_kernel  = kb0_stop  * kb_niter;
@@ -1339,7 +1392,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
-        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+        (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
index 8ddd0415b7f8f2a4f6a7cbbc20c71952d21d5f51..6c1185deac8507b212cda3f8dce5a7c6339e85fd 100644 (file)
@@ -282,7 +282,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
-    if (sinks) {
+    if (sinks && !fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
index ec7ab255188fcdbbbf84b4929b4e2eb58bb10d45..19e9c405ea2ee0f9fb5f08661bcf427f6c41e6d0 100644 (file)
@@ -3532,7 +3532,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
             // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                    && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
                 return false;
             }
             if (op->src[0]->ne[0] == 192) {