]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: add attn sinks support for FA kernels (#15706)
authorrmatif <redacted>
Tue, 2 Sep 2025 06:26:53 +0000 (08:26 +0200)
committerGitHub <redacted>
Tue, 2 Sep 2025 06:26:53 +0000 (23:26 -0700)
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/flash_attn_f16.cl
ggml/src/ggml-opencl/kernels/flash_attn_f32.cl
ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl

index c25c2daaf60ea45d0afbd0745da6ebd0c0c481a0..a9a91ca585b2da10b195b2202a9709bc8fa44365 100644 (file)
@@ -2776,10 +2776,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                if (op->src[4]) {
-                    return false;
-                }
-
                 const ggml_tensor * q = op->src[0];
                 const ggml_tensor * k = op->src[1];
                 const ggml_tensor * v = op->src[2];
@@ -5765,6 +5761,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
 static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
     const ggml_tensor * v = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
     GGML_ASSERT(q->extra);
     GGML_ASSERT(k->extra);
     GGML_ASSERT(v->extra);
@@ -5772,6 +5769,9 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
     if (mask) {
         GGML_ASSERT(mask->extra);
     }
+    if (sinks) {
+        GGML_ASSERT(sinks->extra);
+    }
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -5813,6 +5813,7 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
     ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
     ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
     ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
+    ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
 
     cl_ulong offset_q = extra_q->offset + q->view_offs;
     cl_ulong offset_k = extra_k->offset + k->view_offs;
@@ -5820,6 +5821,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
     cl_ulong offset_o = extra_o->offset + dst->view_offs;
     cl_mem   mask_buffer = extra_mask ? extra_mask->data_device : NULL;
     cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
+    cl_mem   sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
+    cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
 
     const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
     const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
@@ -5874,6 +5877,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
     CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
     CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int),      &mask_ne2));
     CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int),      &mask_ne3));
+    CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem),   &sinks_buffer));
+    CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
 
     if (n_q == 1) {
         const size_t wg_size = 64;
index fea06867e1020842c8bc0db376f98c7e627c7066..8f43c4f27d58caa685cdfd967c59fed426f3bd2b 100644 (file)
@@ -49,7 +49,9 @@ __kernel void flash_attn_f16(
     const ulong mask_nb2,
     const ulong mask_nb3,
     const int mask_ne2,
-    const int mask_ne3
+    const int mask_ne3,
+    const global void* sinks_void,
+    const ulong sinks_offset
 ) {
     const int tid = get_local_id(0);
     const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f16(
     }
 
     if (my_query_row < n_q) {
+        if (sinks_void != NULL) {
+            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+            const ACC_TYPE m_sink = sinks_ptr[head_idx];
+            const ACC_TYPE m_final = max(m_i, m_sink);
+
+            const ACC_TYPE scale_o = exp(m_i - m_final);
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_acc[i] *= scale_o;
+            }
+
+            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+        }
+
         const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
         global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
         if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1(
     const ulong mask_nb2,
     const ulong mask_nb3,
     const int mask_ne2,
-    const int mask_ne3
+    const int mask_ne3,
+    const global void* sinks_void,
+    const ulong sinks_offset
 ) {
     const int tid = get_local_id(0);
     const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1(
 
     float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
 
-    ACC_TYPE m_i = -INFINITY;
+    const global ACC_TYPE* sinks_ptr = NULL;
+    if (sinks_void != NULL) {
+        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+    }
+
+    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
     for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
         const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
         const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1(
 
     const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
     global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
-    const ACC_TYPE l_final = local_l[0];
+    ACC_TYPE l_final = local_l[0];
+
+    if (sinks_ptr != NULL) {
+        l_final += exp(sinks_ptr[head_idx] - m_final);
+    }
 
     if (l_final > 0.0f) {
         const ACC_TYPE l_inv = 1.0f / l_final;
index 2d657327d6460476262358b19277dcd55cd0b41b..9c0bab135a912a7dc931c05915a452f990a80e91 100644 (file)
@@ -49,7 +49,9 @@ __kernel void flash_attn_f32(
     const ulong mask_nb2,
     const ulong mask_nb3,
     const int mask_ne2,
-    const int mask_ne3
+    const int mask_ne3,
+    const global void* sinks_void,
+    const ulong sinks_offset
 ) {
     const int tid = get_local_id(0);
     const int block_q_idx = get_group_id(0);
@@ -171,6 +173,20 @@ __kernel void flash_attn_f32(
     }
 
     if (my_query_row < n_q) {
+        if (sinks_void != NULL) {
+            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+            const ACC_TYPE m_sink = sinks_ptr[head_idx];
+            const ACC_TYPE m_final = max(m_i, m_sink);
+
+            const ACC_TYPE scale_o = exp(m_i - m_final);
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_acc[i] *= scale_o;
+            }
+
+            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+        }
+
         const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
         global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
         if (l_i > 0.0f) {
@@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1(
     const ulong mask_nb2,
     const ulong mask_nb3,
     const int mask_ne2,
-    const int mask_ne3
+    const int mask_ne3,
+    const global void* sinks_void,
+    const ulong sinks_offset
 ) {
     const int tid = get_local_id(0);
     const int head_batch_idx = get_global_id(1);
@@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1(
 
     float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
 
-    ACC_TYPE m_i = -INFINITY;
+    const global ACC_TYPE* sinks_ptr = NULL;
+    if (sinks_void != NULL) {
+        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+    }
+
+    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
     for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
         const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
         const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1(
 
     const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
     global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
-    const ACC_TYPE l_final = local_l[0];
+    ACC_TYPE l_final = local_l[0];
+
+    if (sinks_ptr != NULL) {
+        l_final += exp(sinks_ptr[head_idx] - m_final);
+    }
 
     if (l_final > 0.0f) {
         const ACC_TYPE l_inv = 1.0f / l_final;
index 7067bd2591fa7fe09eaf7bc063f303885b1cde7d..ec7361b9e3709c512fbc75abf009bce9e3ad9c97 100644 (file)
@@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16(
     const ulong mask_nb2,
     const ulong mask_nb3,
     const int mask_ne2,
-    const int mask_ne3
+    const int mask_ne3,
+    const global void* sinks_void,
+    const ulong sinks_offset
 ) {
     const int tid = get_local_id(0);
     const int block_q_idx = get_group_id(0);
@@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16(
     }
 
     if (my_query_row < n_q) {
+        if (sinks_void != NULL) {
+            const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+            const ACC_TYPE m_sink = sinks_ptr[head_idx];
+            const ACC_TYPE m_final = max(m_i, m_sink);
+
+            const ACC_TYPE scale_o = exp(m_i - m_final);
+            #pragma unroll
+            for (int i = 0; i < DV_VEC; ++i) {
+                o_acc[i] *= scale_o;
+            }
+
+            l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
+        }
+
         const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
         global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
         if (l_i > 0.0f) {
@@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1(
     const ulong mask_nb2,
     const ulong mask_nb3,
     const int mask_ne2,
-    const int mask_ne3
+    const int mask_ne3,
+    const global void* sinks_void,
+    const ulong sinks_offset
 ) {
     const int tid = get_local_id(0);
     const int head_batch_idx = get_global_id(1);
@@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1(
 
     float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
 
-    ACC_TYPE m_i = -INFINITY;
+    const global ACC_TYPE* sinks_ptr = NULL;
+    if (sinks_void != NULL) {
+        sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
+    }
+
+    ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
     for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
         const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
         const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
@@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1(
 
     const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
     global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
-    const ACC_TYPE l_final = local_l[0];
+    ACC_TYPE l_final = local_l[0];
+
+    if (sinks_ptr != NULL) {
+        l_final += exp(sinks_ptr[head_idx] - m_final);
+    }
 
     if (l_final > 0.0f) {
         const ACC_TYPE l_inv = 1.0f / l_final;