]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: support sink in `soft_max` (attn sinks) (#15152)
authorlhez <redacted>
Fri, 8 Aug 2025 04:47:03 +0000 (13:47 +0900)
committerGitHub <redacted>
Fri, 8 Aug 2025 04:47:03 +0000 (21:47 -0700)
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
ggml/src/ggml-opencl/kernels/softmax_f16.cl
ggml/src/ggml-opencl/kernels/softmax_f32.cl

index 4f765ab53092138a3c9979baca0a65d7803f22e8..b32d5da307545ba6074bd5a4ad4434f40d9b060c 100644 (file)
@@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
-            // TODO: support attention sinks [TAG_ATTN_SINKS]
-            return op->src[2] == nullptr;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
             return true;
@@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
         GGML_ASSERT(src1->extra);
     }
 
+    const ggml_tensor * src2 = dst->src[2];
+    if (src2) {
+        GGML_ASSERT(src2->extra);
+    }
+
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
 
     ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
 
     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
     cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
 
     const int ne00 = src0->ne[0];
     const int ne01 = src0->ne[1];
@@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
     CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? &extra1->data_device : &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
-    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
-    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float),    &scale));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float),    &max_bias));
-    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &m0));
-    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &m1));
-    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &n_head_log2));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_head_log2));
 
     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)nth, 1, 1};
index a6d8ede67010db92a11133d14e9bb366c406e0dd..571d16507c6f31a6393b6b6aadc19891255eb5ef 100644 (file)
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16(
 
     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half4  * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;
index 35b5573b46a81bd893dcdf1d9ac2e0193a3fcffd..1f944b2201d5ab0199a01350b3da45769d08fc4d 100644 (file)
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_4(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_4(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_4(
 
     global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float  * psrc2 = src2 != src0 ? (global float  *)(src2) : 0;
     global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -92,7 +96,11 @@ kernel void kernel_soft_max_4(
     }
     float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
         pdst4[i00] /= sum;
index 9d292b57465a531831999dddcb64630fc9bf392f..4baa6c28e4f0e719d5e26767c1d3cf64f28d5c64 100644 (file)
@@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16(
 
     global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16(
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         pdst[i00] /= sum;
index 7c53dfbe5a27c85d38328cfddecb0e926ea89645..d503190b4765113a339d690b6035be1709875d8d 100644 (file)
@@ -26,6 +26,8 @@ kernel void kernel_soft_max(
         ulong offset0,
         global char * src1,
         ulong offset1,
+        global char * src2,
+        ulong offset2,
         global char * dst,
         ulong offsetd,
         int ne00,
@@ -48,6 +50,7 @@ kernel void kernel_soft_max(
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
+    src2 = src2 + offset2;
     dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
@@ -60,6 +63,7 @@ kernel void kernel_soft_max(
 
     global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
     global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0;
     global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
@@ -75,7 +79,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
@@ -91,7 +95,11 @@ kernel void kernel_soft_max(
         pdst[i00] = exp_psrc0;
     }
 
-    const float sum = sub_group_reduce_add(lsum);
+    float sum = sub_group_reduce_add(lsum);
+
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max);
+    }
 
     for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
         pdst[i00] /= sum;