]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: add fused `rms_norm_mul` (#14841)
authorlhez <redacted>
Fri, 25 Jul 2025 15:12:13 +0000 (08:12 -0700)
committerGitHub <redacted>
Fri, 25 Jul 2025 15:12:13 +0000 (17:12 +0200)
* opencl: add fused `rms_norm` + `mul`

* opencl: improve workgroup size for `rms_norm_mul`

ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/rms_norm.cl

index 63ac4a989b08b8da5ab7543ece9a7ba2cbd77ae3..c87a32383c8735ae49bd387f16ad16c82f5a243d 100644 (file)
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
               kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm     = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max   = clCreateBuffer(context, 0, max_B_d_bytes,   NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS
 
+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }
 
+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),        &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong),      &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),        &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong),      &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),        &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong),      &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),           &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),           &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),           &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),           &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),      &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),      &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),      &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),           &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),           &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),           &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),           &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),      &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),      &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),      &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong),      &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong),      &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong),      &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),         &eps));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
index 9d21f3398ec38bf343ba61512a1253aeaf19efd3..ecd053cb4c1ce7556c57d765fa4ba3abf112b90d 100644 (file)
@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
         }
     }
 }
+
+//------------------------------------------------------------------------------
+// rms_norm_mul
+//------------------------------------------------------------------------------
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_32
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_rms_norm_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        float eps,
+        local float * sum
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);
+
+    float sumf = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += dot(x[i00], x[i00]);
+    }
+    sumf = sub_group_reduce_add(sumf);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = sumf;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+       if (get_local_id(0) < i) {
+           sum[get_local_id(0)] += sum[get_local_id(0) + i];
+       }
+    }
+    if (get_local_id(0) == 0) {
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    float mean  = sum[0];
+    float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
+    }
+}