git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: allow mixed f16/f32 `add` (#15140)
author rmatif <redacted>
Tue, 12 Aug 2025 09:42:41 +0000 (11:42 +0200)
committer GitHub <redacted>
Tue, 12 Aug 2025 09:42:41 +0000 (02:42 -0700)
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/add.cl

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 8ba1e00df7bad4dd3e98f811ade06e39720e0d86..fc838684ac6fa612d831490a6bb87df3c1d8b2c6 100644
@@ -2481,6 +2481,13 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_SCALE:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
+            if (op->type == GGML_TYPE_F16) {
+                const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32;
+                const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32;
+                if (src0_ok && src1_ok) {
+                    return true;
+                }
+            }
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_SUB:
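
The new GGML_OP_ADD branch only short-circuits for f16 destinations; every other combination keeps falling through to the shared check used by GGML_OP_MUL/DIV/SUB further down the switch (outside this hunk). A minimal standalone restatement of the rule, with that shared path stubbed in as an assumption:

    // Standalone restatement of the new GGML_OP_ADD acceptance rule (hypothetical
    // helper, not part of the tree). An f16 destination accepts any mix of
    // f16/f32 sources; anything else defers to the pre-existing shared path,
    // assumed here to be the old f32-only rule.
    #include "ggml.h"

    static bool cl_add_supported(enum ggml_type dst, enum ggml_type src0, enum ggml_type src1) {
        if (dst == GGML_TYPE_F16) {
            const bool src0_ok = src0 == GGML_TYPE_F16 || src0 == GGML_TYPE_F32;
            const bool src1_ok = src1 == GGML_TYPE_F16 || src1 == GGML_TYPE_F32;
            if (src0_ok && src1_ok) {
                return true;
            }
        }
        // stand-in for the fall-through to the shared MUL/DIV/SUB check
        return dst == GGML_TYPE_F32 && src0 == GGML_TYPE_F32 && src1 == GGML_TYPE_F32;
    }
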
@@ -3717,34 +3724,30 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
 
-    GGML_ASSERT(src0->type == src1->type);
-    GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
-
-    const int  ne00 = src0->ne[0];
-    const int  ne01 = src0->ne[1];
-    const int  ne02 = src0->ne[2];
-    const int  ne03 = src0->ne[3];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
 
     const cl_ulong nb00 = src0->nb[0];
     const cl_ulong nb01 = src0->nb[1];
     const cl_ulong nb02 = src0->nb[2];
     const cl_ulong nb03 = src0->nb[3];
 
-    const int  ne10 = src1->ne[0];
-    const int  ne11 = src1->ne[1];
-    const int  ne12 = src1->ne[2];
-    const int  ne13 = src1->ne[3]; UNUSED(ne13);
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
 
     const cl_ulong nb10 = src1->nb[0];
     const cl_ulong nb11 = src1->nb[1];
     const cl_ulong nb12 = src1->nb[2];
-    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
+    const cl_ulong nb13 = src1->nb[3];
 
-    const int  ne0  = dst->ne[0];
-    const int  ne1  = dst->ne[1];
-    const int  ne2  = dst->ne[2];
-    const int  ne3  = dst->ne[3];
+    const int ne0  = dst->ne[0];
+    const int ne1  = dst->ne[1];
+    const int ne2  = dst->ne[2];
+    const int ne3  = dst->ne[3];
 
     const cl_ulong nb0  = dst->nb[0];
     const cl_ulong nb1  = dst->nb[1];
@@ -3761,68 +3764,114 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
     cl_ulong offset1 = extra1->offset + src1->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
-    bool bcast_row = false;
     cl_kernel kernel;
 
-    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
-        GGML_ASSERT(ggml_is_contiguous(src0));
+    const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0;
 
-        // src1 is a row
+    if (bcast_row) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
         GGML_ASSERT(ne11 == 1);
+    }
 
-        bcast_row = true;
-        int ne = ne00 / 4;
-
-        if (src0->type == GGML_TYPE_F32) {
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32);
+        if (bcast_row) {
             kernel = backend_ctx->kernel_add_row;
+            const int ne = ne00 / 4;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
         } else {
-            kernel = backend_ctx->kernel_add_row_f16;
-        }
-
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
-        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
-        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
-    } else {
-        if (src0->type == GGML_TYPE_F32) {
             kernel = backend_ctx->kernel_add;
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+        }
+    } else if (dst->type == GGML_TYPE_F16) {
+        GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+        GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
+        const int type_src0 = (src0->type == GGML_TYPE_F32);
+        const int type_src1 = (src1->type == GGML_TYPE_F32);
+        if (bcast_row) {
+            kernel = backend_ctx->kernel_add_row_f16;
+            const int ne = ne00 / 4;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),      &type_src0));
+            CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int),      &type_src1));
         } else {
             kernel = backend_ctx->kernel_add_f16;
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+            CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+            CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+            CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+            CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+            CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+            CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int),      &type_src0));
+            CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int),      &type_src1));
         }
-
-        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
-        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
-        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
-        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
-        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
-        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
-        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
-        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
-        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
-        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
-        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
-        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
-        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
-        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
-        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
-        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
-        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
-        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
-        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
-        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
-        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
-        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
-        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
-        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
-        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
-        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
-        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
-        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
-        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
-        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+    } else {
+        GGML_ASSERT(false && "unsupported data types for add");
     }
 
     if (bcast_row) {
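
With the kernel now chosen from dst->type, and the two type_src0/type_src1 flags (1 = f32) forwarded to the f16 kernels, a mixed-type add with an f16 destination is no longer rejected by ggml_opencl_supports_op. A minimal graph-side sketch of such a node, assuming ggml_add accepts mixed f16/f32 operands at construction time (the result type follows src0):

    // Hypothetical host-side usage sketch, not part of this commit: build an
    // f16 + f32 add whose f16 result the OpenCL backend can now execute with
    // kernel_add_row_f16 / kernel_add_f16 and type_src1 = 1.
    #include "ggml.h"

    static struct ggml_tensor * build_mixed_add(struct ggml_context * ctx, int64_t n) {
        struct ggml_tensor * acc  = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n); // f16 accumulator
        struct ggml_tensor * bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // f32 bias row
        return ggml_add(ctx, acc, bias);                                       // dst type follows acc: f16
    }
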
@@ -3832,13 +3881,13 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
 
         size_t * local_work_size_ptr = local_work_size;
         if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
-            local_work_size_ptr = nullptr;  // Let driver choose the work-group sizes.
+            local_work_size_ptr = nullptr;
         }
 
-        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+        backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst);
     } else {
         unsigned int nth = MIN(64, ne0);
-        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
         size_t local_work_size[] = {nth, 1, 1};
 
         backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
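
The last host hunk also tightens the dispatch: the row-broadcast path is enqueued as a 1-D ND-range (the row kernels only use get_global_id(0)), and the added (size_t) cast makes the ne01*nth product evaluate in size_t rather than 32-bit arithmetic. A compact restatement of the general-path work-size setup, as a hypothetical helper mirroring the code above:

    // Work sizes for the general (non-broadcast) path: up to 64 threads per dst
    // row, one work-group per (i1, i2, i3) slice. Widening ne01 to size_t before
    // the multiply means the product cannot wrap in 32-bit arithmetic for very
    // large row counts.
    #include <algorithm>
    #include <cstddef>

    static void plan_general_add(int ne0, int ne01, int ne02, int ne03,
                                 size_t global[3], size_t local[3]) {
        const size_t nth = std::min<size_t>(64, (size_t) ne0);
        global[0] = (size_t) ne01 * nth;
        global[1] = (size_t) ne02;
        global[2] = (size_t) ne03;
        local[0]  = nth;
        local[1]  = 1;
        local[2]  = 1;
    }
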
diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl
index 8bc926c88931f43f2c30b93fa261e0db9d9e0aad..509bf17344ea6cef3a24b23db2c8c0c4796e4596 100644
@@ -112,7 +112,9 @@ kernel void kernel_add_f16(
         ulong nb0,
         ulong nb1,
         ulong nb2,
-        ulong nb3
+        ulong nb3,
+        int type_src0,
+        int type_src1
 ) {
     src0 = src0 + offset0;
     src1 = src1 + offset1;
@@ -132,25 +134,57 @@ kernel void kernel_add_f16(
 
     for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
         const int i10 = i0 % ne10;
-        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) + *((global half *)(src1_ptr + i10*nb10));
+
+        half v0, v1;
+        if (type_src0 == 1) {
+            v0 = convert_half(*((global float *)(src0_ptr + i0*nb00)));
+        } else {
+            v0 = *((global half *)(src0_ptr + i0*nb00));
+        }
+
+        if (type_src1 == 1) {
+            v1 = convert_half(*((global float *)(src1_ptr + i10*nb10)));
+        } else {
+            v1 = *((global half *)(src1_ptr + i10*nb10));
+        }
+
+        *((global half *)(dst_ptr + i0*nb0)) = v0 + v1;
     }
 }
 
 kernel void kernel_add_row_f16(
-        global half4 * src0,
+        global char * src0,
         ulong  offset0,
-        global half4 * src1,
+        global char * src1,
         ulong  offset1,
         global half4 * dst,
         ulong  offsetd,
-        int ne
+        int ne,
+        int type_src0,
+        int type_src1
 ) {
-    src0 = (global half4*)((global char*)src0 + offset0);
-    src1 = (global half4*)((global char*)src1 + offset1);
     dst = (global half4*)((global char*)dst + offsetd);
 
     // This performs better than using %.
     uint gid = get_global_id(0);
     uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
-    dst[gid] = src0[gid] + src1[idx1];
+
+    half4 v0, v1;
+    if (type_src0 == 1) {
+        global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
+        v0 = convert_half4(src0_f32[gid]);
+    } else {
+        global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
+        v0 = src0_f16[gid];
+    }
+
+    if (type_src1 == 1) {
+        global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
+        v1 = convert_half4(src1_f32[idx1]);
+    } else {
+        global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
+        v1 = src1_f16[idx1];
+    }
+
+    dst[gid] = v0 + v1;
 }
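
In the kernels, kernel_add_row_f16 now takes char pointers for its sources, and both f16 kernels receive the two type flags and convert any f32 operand with convert_half/convert_half4 before the add; the destination stays half. A hedged scalar model of the per-element behaviour, using ggml's public fp16 conversion helpers (hypothetical test code, not part of the commit):

    // Per-element model of kernel_add_f16 with mixed inputs: an f32 operand is
    // rounded to half first, then the add itself runs in half precision, so a
    // mixed add gains no extra intermediate precision over a pure-f16 add. The
    // device's half addition is modelled here as round(f32(a) + f32(b)).
    #include "ggml.h"

    static ggml_fp16_t ref_add_f16(ggml_fp16_t a, ggml_fp16_t b) {
        return ggml_fp32_to_fp16(ggml_fp16_to_fp32(a) + ggml_fp16_to_fp32(b));
    }

    static ggml_fp16_t ref_add_mixed(ggml_fp16_t a, float b_f32) {
        return ref_add_f16(a, ggml_fp32_to_fp16(b_f32)); // mirrors convert_half() in the kernel
    }
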