]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl : broadcast for soft_max (#14510)
authorlhez <redacted>
Thu, 3 Jul 2025 18:22:24 +0000 (11:22 -0700)
committerGitHub <redacted>
Thu, 3 Jul 2025 18:22:24 +0000 (20:22 +0200)
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/softmax_4_f16.cl
ggml/src/ggml-opencl/kernels/softmax_4_f32.cl
ggml/src/ggml-opencl/kernels/softmax_f16.cl
ggml/src/ggml-opencl/kernels/softmax_f32.cl

index 9436e6ea9a08de7801b60fe4e4291dc3d931048b..2450100b43c95120273c92bf9d69edc2d64c34b1 100644 (file)
@@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
 
     cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
 
-    const int  ne00 = src0 ? src0->ne[0] : 0;
-    const int  ne01 = src0 ? src0->ne[1] : 0;
-    const int  ne02 = src0 ? src0->ne[2] : 0;
-    const int  ne03 = src0 ? src0->ne[3] : 0;
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_long nb01 = src0->nb[1];
+    const cl_long nb02 = src0->nb[2];
+    const cl_long nb03 = src0->nb[3];
+
+    const int ne12 = src1 ? src1->ne[2] : 0;
+    const int ne13 = src1 ? src1->ne[3] : 0;
+
+    const cl_long nb11 = src1 ? src1->nb[1] : 0;
+    const cl_long nb12 = src1 ? src1->nb[2] : 0;
+    const cl_long nb13 = src1 ? src1->nb[3] : 0;
+
+    const cl_long nb1 = dst->nb[1];
+    const cl_long nb2 = dst->nb[2];
+    const cl_long nb3 = dst->nb[3];
 
     float scale, max_bias;
     memcpy(&scale,    dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    const int nrows_x = ggml_nrows(src0);
-    const int nrows_y = src0->ne[1];
-
-    const int n_head      = nrows_x/nrows_y;
+    const int n_head      = src0->ne[2];
     const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
@@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
     CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
     CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
     CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
-    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
-    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(float),    &scale));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float),    &max_bias));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),    &m0));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float),    &m1));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &n_head_log2));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &n_head_log2));
 
     size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
     size_t local_work_size[] = {(size_t)nth, 1, 1};
index 62c05369a87b14f132ba4e5cf211986245eaa2b8..a6d8ede67010db92a11133d14e9bb366c406e0dd 100644 (file)
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max_4_f16(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global half * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    src1 = (global half *)((global char *)src1 + offset1);
-    dst = (global float *)((global char *)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);
 
-    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    global half4  * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0;
-    global float4 * pdst4 = (global float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global half4  * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
 
index d562774eaba5e6c9dbe4c638fa4f5822ead8595a..35b5573b46a81bd893dcdf1d9ac2e0193a3fcffd 100644 (file)
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max_4(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global float * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    src1 = (global float*)((global char*)src1 + offset1);
-    dst = (global float*)((global char*)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);
 
-    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
-    global float4 * pdst4 = (global float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float4 * pdst4 = (global float4 *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
 
index d38d099671ecf61dc88f2d814b1a17f933228980..9d292b57465a531831999dddcb64630fc9bf392f 100644 (file)
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max_f16(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global half * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float *)((global char *)src0 + offset0);
-    src1 = (global half *)((global char *)src1 + offset1);
-    dst = (global float *)((global char *)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);
 
-    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    global half  * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0;
-    global float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global half  * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;
 
index 001b587abe31e5c00b826c7122dbd40d9d83d50d..7c53dfbe5a27c85d38328cfddecb0e926ea89645 100644 (file)
 REQD_SUBGROUP_SIZE_64
 #endif
 kernel void kernel_soft_max(
-        global float * src0,
+        global char * src0,
         ulong offset0,
-        global float * src1,
+        global char * src1,
         ulong offset1,
-        global float * dst,
+        global char * dst,
         ulong offsetd,
         int ne00,
-        int ne01,
-        int ne02,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
         float scale,
         float max_bias,
         float m0,
         float m1,
         int n_head_log2
 ) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    src1 = (global float*)((global char*)src1 + offset1);
-    dst = (global float*)((global char*)dst + offsetd);
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
 
     int i03 = get_group_id(2);
     int i02 = get_group_id(1);
     int i01 = get_group_id(0);
 
-    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
-    global float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    int i13 = i03%ne13;
+    int i12 = i02%ne12;
+    int i11 = i01;
+
+    global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
+    global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
+    global float * pdst  = (global float *)(dst  + i01*nb1 + i02*nb2 + i03*nb3);
 
     float slope = 1.0f;