]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: add `swiglu_oai` and `add_id` (#15121)
authorlhez <redacted>
Wed, 6 Aug 2025 19:12:17 +0000 (04:12 +0900)
committerGitHub <redacted>
Wed, 6 Aug 2025 19:12:17 +0000 (12:12 -0700)
* opencl: add `swiglu-oai`

* opencl: add `add_id`

* opencl: add missing `add_id.cl`

ggml/src/ggml-opencl/CMakeLists.txt
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/add_id.cl [new file with mode: 0644]
ggml/src/ggml-opencl/kernels/glu.cl

index 3adea83615437b8951e7602285a8c4198e5e2c9a..d8290faa467d01d4ea3c5a0af032908adb266bfe 100644 (file)
@@ -55,6 +55,7 @@ endfunction()
 
 set(GGML_OPENCL_KERNELS
     add
+    add_id
     argsort
     clamp
     cpy
index bb8b310b983e270135b5efd0061bc18442c0def3..eea6ad6cab3d0cc2b9e80c05a34ef6d926b6f466 100644 (file)
@@ -345,6 +345,7 @@ struct ggml_backend_opencl_context {
     cl_command_queue queue;
 
     cl_program program_add;
+    cl_program program_add_id;
     cl_program program_clamp;
     cl_program program_cpy;
     cl_program program_cvt;
@@ -404,6 +405,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_mul, kernel_mul_row, kernel_mul_f16, kernel_mul_row_f16;
     cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
     cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
+    cl_kernel kernel_add_id;
     cl_kernel kernel_scale;
     cl_kernel kernel_silu, kernel_silu_4;
     cl_kernel kernel_gelu, kernel_gelu_4;
@@ -412,7 +414,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_relu;
     cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16;
     cl_kernel kernel_clamp;
-    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
+    cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
     cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
@@ -681,6 +683,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // add_id
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "add_id.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("add_id.cl");
+#endif
+        backend_ctx->program_add_id =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_add_id = clCreateKernel(backend_ctx->program_add_id, "kernel_add_id", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
     // clamp
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -787,6 +805,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         CL_CHECK((backend_ctx->kernel_geglu           = clCreateKernel(backend_ctx->program_glu, "kernel_geglu", &err), err));
         CL_CHECK((backend_ctx->kernel_reglu           = clCreateKernel(backend_ctx->program_glu, "kernel_reglu", &err), err));
         CL_CHECK((backend_ctx->kernel_swiglu          = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu", &err), err));
+        CL_CHECK((backend_ctx->kernel_swiglu_oai      = clCreateKernel(backend_ctx->program_glu, "kernel_swiglu_oai", &err), err));
         CL_CHECK((backend_ctx->kernel_geglu_erf       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_erf", &err), err));
         CL_CHECK((backend_ctx->kernel_geglu_quick     = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_quick", &err), err));
         CL_CHECK((backend_ctx->kernel_geglu_f16       = clCreateKernel(backend_ctx->program_glu, "kernel_geglu_f16", &err), err));
@@ -2467,6 +2486,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return (op->src[0]->type == op->src[1]->type) &&
                    (op->src[0]->type == op->type) &&
                    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
+        case GGML_OP_ADD_ID:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
@@ -2488,6 +2509,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
@@ -3824,6 +3846,75 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const
     }
 }
 
+static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const ggml_tensor * src2 = dst->src[2];
+    GGML_ASSERT(src2);
+    GGML_ASSERT(src2->extra);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+
+    const cl_ulong nb11 = src1->nb[1];
+
+    const cl_ulong nb21 = src2->nb[1];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offset2 = extra2->offset + src2->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_add_id;
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extra2->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne1));
+
+    int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
+    size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
+    size_t local_work_size[] = { (size_t)nth, 1, 1 };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -7005,6 +7096,9 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
                 kernel = backend_ctx->kernel_swiglu_f16;
             }
             break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            kernel = backend_ctx->kernel_swiglu_oai;
+            break;
         case GGML_GLU_OP_GEGLU_ERF:
             if (dst->type == GGML_TYPE_F32) {
                 kernel = backend_ctx->kernel_geglu_erf;
@@ -7040,7 +7134,10 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
 
     const cl_ulong nb1  = dst->nb[1];
 
-    const int swp = ((const int32_t *) dst->op_params)[1];
+    const int   swp   = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
     const int ne00_off = src1 ? 0 : (swp ? ne0 : 0);
     const int ne10_off = src1 ? 0 : (swp ? 0 : ne0);
 
@@ -7057,6 +7154,11 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const
     CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne00_off));
     CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne10_off));
 
+    if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) {
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &limit));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &alpha));
+    }
+
     const size_t nrows = ggml_nrows(src0);
     size_t nth = 512;
     size_t global_work_size[] = {nrows*nth, 1, 1};
@@ -7113,6 +7215,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_add;
             break;
+        case GGML_OP_ADD_ID:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_add_id;
+            break;
         case GGML_OP_MUL:
             if (!any_on_device) {
                 return false;
diff --git a/ggml/src/ggml-opencl/kernels/add_id.cl b/ggml/src/ggml-opencl/kernels/add_id.cl
new file mode 100644 (file)
index 0000000..e9c6d55
--- /dev/null
@@ -0,0 +1,42 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// add_id
+//------------------------------------------------------------------------------
+kernel void kernel_add_id(
+    global char * src0,
+    ulong         offset0,
+    global char * src1,
+    ulong         offset1,
+    global char * src2,
+    ulong         offset2,
+    global char * dst,
+    ulong         offsetd,
+    ulong         nb01,
+    ulong         nb02,
+    ulong         nb11,
+    ulong         nb21,
+    int           ne0,
+    int           ne1
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    src2 = (global char*)((global char*)src2 + offset2);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    int i1 = get_group_id(0);
+    int i2 = get_group_id(1);
+
+    const int i11 = *((global const int *) (src2 + i1*sizeof(int) + i2*nb21));
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    global float * dst_row  = (global float *)((global char *)dst  + i1*nb1 + i2*nb2);
+    global float * src0_row = (global float *)((global char *)src0 + i1*nb01 + i2*nb02);
+    global float * src1_row = (global float *)((global char *)src1 + i11*nb11);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
index 7cca16e6a9e7e5f271c2f328fa1fc9d730bc163a..059a4bbf1ba7cc80fd734bbd6af36a01c449e65a 100644 (file)
@@ -202,6 +202,47 @@ kernel void kernel_swiglu_f16(
     }
 }
 
+//------------------------------------------------------------------------------
+// swiglu_oai
+//------------------------------------------------------------------------------
+kernel void kernel_swiglu_oai(
+    global char * src0,
+    ulong         offset0,
+    global char * src1,
+    ulong         offset1,
+    global char * dst,
+    ulong         offsetd,
+    ulong         nb01,
+    ulong         nb11,
+    int           ne0,
+    ulong         nb1,
+    int           ne00_off,
+    int           ne10_off,
+    float         limit,
+    float         alpha
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst  = (global char*)((global char*)dst  + offsetd);
+
+    global float * src0_row = (global float *) ((global char *) src0 + get_group_id(0)*nb01) + ne00_off;
+    global float * src1_row = (global float *) ((global char *) src1 + get_group_id(0)*nb11) + ne10_off;
+    global float * dst_row  = (global float *) ((global char *) dst  + get_group_id(0)*nb1);
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        float x0 = src0_row[i0];
+        float x1 = src1_row[i0];
+
+        x0 = min(x0, limit);
+        x1 = max(min(x1, limit), -limit);
+
+        float out_glu = x0 / (1.0f + exp(-x0 * alpha));
+        out_glu = out_glu * (1.0f + x1);
+
+        dst_row[i0] = out_glu;
+    }
+}
+
 //------------------------------------------------------------------------------
 // geglu_erf
 //------------------------------------------------------------------------------