]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fuse adds, fuse add with rms norm (#15631)
authorAman Gupta <redacted>
Fri, 29 Aug 2025 03:35:58 +0000 (11:35 +0800)
committerGitHub <redacted>
Fri, 29 Aug 2025 03:35:58 +0000 (11:35 +0800)
* CUDA: fused add with rms_norm_mul

* Non-broadcast fuse works

* Add fused adds

* format

* Remove n_fuse from template params

* Address review comments

* Move template inside binbcast

ggml/src/ggml-cuda/binbcast.cu
ggml/src/ggml-cuda/binbcast.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/norm.cu
ggml/src/ggml-cuda/norm.cuh

index e1fbf0e13665d016e68d19c07f77e6970dfb3301..99a98fcbfcdb36fe364b575929c556da93eb910d 100644 (file)
@@ -1,5 +1,6 @@
 #include "binbcast.cuh"
 #include <cstdint>
+#include <utility>
 
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
@@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
     return a / b;
 }
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13) {
+        const int ne0, const int ne1, const int ne2, const int ne3,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        /*int s0, */ const int s1, const int s2, const int s3,
+        /*int s00,*/ const int s01, const int s02, const int s03,
+        /*int s10,*/ const int s11, const int s12, const int s13,
+        src1_ptrs... src1s) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -46,24 +50,27 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
     const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
+    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
         const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+        float result = src0_row ? (float) src0_row[i0] : 0.0f;
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+
+        dst_row[i0] = (dst_t) result;
     }
 }
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13) {
-
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t *   src0, const src1_t *   src1, dst_t *          dst,
+        const int ne0, const int ne1, const int ne2,const int ne3,
+        const int ne10, const int ne11, const int ne12, const int ne13,
+        /*int s0, */ const int s1, const int s2, const int s3,
+        /*int s00,*/ const int s01, const int s02, const int s03,
+        /*int s10,*/ const int s11, const int s12, const int s13,
+        src1_ptrs ... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     const int i3 = i/(ne2*ne1*ne0);
@@ -83,12 +90,166 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
     const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
+    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
     const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+    float result = src0_row ? (float) src0_row[i0] : 0.0f;
+    result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+
+    dst_row[i0] = (dst_t) result;
+}
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+                                  cudaStream_t stream, std::index_sequence<I...>) {
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    int nr0 = ne10 / ne0;
+    int nr1 = ne11 / ne1;
+    int nr2 = ne12 / ne2;
+    int nr3 = ne13 / ne3;
+
+    int nr[4] = { nr0, nr1, nr2, nr3 };
+
+    int64_t cne[]  = { ne0, ne1, ne2, ne3 };
+    int64_t cne0[] = { ne00, ne01, ne02, ne03 };
+    int64_t cne1[] = { ne10, ne11, ne12, ne13 };
+
+    size_t cnb[]  = { nb0, nb1, nb2, nb3 };
+    size_t cnb0[] = { nb00, nb01, nb02, nb03 };
+    size_t cnb1[] = { nb10, nb11, nb12, nb13 };
+
+    auto collapse = [](int64_t cne[]) {
+        cne[0] *= cne[1];
+        cne[1] = cne[2];
+        cne[2] = cne[3];
+        cne[3] = 1;
+    };
+
+    auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+        cnb[1] *= cne[1];
+        cnb[2] *= cne[2];
+        cnb[3] *= cne[3];
+    };
+
+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb, cne);
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+    }
+
+    {
+        int64_t ne0 = cne[0];
+        int64_t ne1 = cne[1];
+        int64_t ne2 = cne[2];
+        int64_t ne3 = cne[3];
+
+        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+        int64_t ne10 = cne1[0];
+        int64_t ne11 = cne1[1];
+        int64_t ne12 = cne1[2];
+        int64_t ne13 = cne1[3];
+
+        size_t nb0 = cnb[0];
+        size_t nb1 = cnb[1];
+        size_t nb2 = cnb[2];
+        size_t nb3 = cnb[3];
+
+        size_t nb00 = cnb0[0];
+        size_t nb01 = cnb0[1];
+        size_t nb02 = cnb0[2];
+        size_t nb03 = cnb0[3];
+
+        size_t nb10 = cnb1[0];
+        size_t nb11 = cnb1[1];
+        size_t nb12 = cnb1[2];
+        size_t nb13 = cnb1[3];
+
+        size_t s0 = nb0 / sizeof(dst_t);
+        size_t s1 = nb1 / sizeof(dst_t);
+        size_t s2 = nb2 / sizeof(dst_t);
+        size_t s3 = nb3 / sizeof(dst_t);
+
+        size_t s10 = nb10 / sizeof(src1_t);
+        size_t s11 = nb11 / sizeof(src1_t);
+        size_t s12 = nb12 / sizeof(src1_t);
+        size_t s13 = nb13 / sizeof(src1_t);
+
+        size_t s00 = nb00 / sizeof(src0_t);
+        size_t s01 = nb01 / sizeof(src0_t);
+        size_t s02 = nb02 / sizeof(src0_t);
+        size_t s03 = nb03 / sizeof(src0_t);
+
+        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+        GGML_ASSERT(s0 == 1);
+        GGML_ASSERT(s00 == 1);
+        GGML_ASSERT(s10 == 1);
+
+        const int block_size = 128;
+
+        int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+        dim3 block_dims;
+        block_dims.x = std::min<unsigned int>(hne0, block_size);
+        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
+                        (ne1 + block_dims.y - 1) / block_dims.y,
+                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+        if (block_nums.z > 65535) {
+            int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+            k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+                <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00,*/ s01, s02, s03,
+                    /* s10,*/ s11, s12,s13,
+                    (const src1_t *) dst->src[I + 1]->data...);
+        } else {
+            k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+                <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                    ne0, ne1, ne2, ne3,
+                    ne10, ne11, ne12, ne13,
+                    /* s0, */ s1, s2, s3,
+                    /* s00,*/ s01, s02, s03,
+                    /* s10,*/ s11, s12,s13,
+                    (const src1_t *) dst->src[I + 1]->data...);
+        }
+    }
 }
 
 template <typename T>
@@ -120,160 +281,14 @@ static __global__ void k_repeat_back(
     dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
-template<float (*bin_op)(const float, const float)>
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
 struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>
     void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
             const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
             cudaStream_t stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne[] = {ne0, ne1, ne2, ne3};
-        int64_t cne0[] = {ne00, ne01, ne02, ne03};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-
-        size_t cnb[] = {nb0, nb1, nb2, nb3};
-        size_t cnb0[] = {nb00, nb01, nb02, nb03};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            for (int i = 0; i < 4; i++) {
-                if (nr[i] != 1) {
-                    break;
-                }
-                if (i > 0) {
-                    collapse_nb(cnb, cne);
-                    collapse_nb(cnb0, cne0);
-                    collapse_nb(cnb1, cne1);
-                    collapse(cne);
-                    collapse(cne0);
-                    collapse(cne1);
-                }
-            }
-        }
-
-        {
-            int64_t ne0 = cne[0];
-            int64_t ne1 = cne[1];
-            int64_t ne2 = cne[2];
-            int64_t ne3 = cne[3];
-
-            //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
-            //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
-            //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
-            //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb[0];
-            size_t nb1 = cnb[1];
-            size_t nb2 = cnb[2];
-            size_t nb3 = cnb[3];
-
-            size_t nb00 = cnb0[0];
-            size_t nb01 = cnb0[1];
-            size_t nb02 = cnb0[2];
-            size_t nb03 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            size_t s00 = nb00 / sizeof(src0_t);
-            size_t s01 = nb01 / sizeof(src0_t);
-            size_t s02 = nb02 / sizeof(src0_t);
-            size_t s03 = nb03 / sizeof(src0_t);
-
-            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
-            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
-            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s00 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            dim3 block_dims;
-            block_dims.x = std::min<unsigned int>(hne0, block_size);
-            block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
-            block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
-
-            dim3 block_nums(
-                (hne0 + block_dims.x - 1) / block_dims.x,
-                (ne1 + block_dims.y - 1) / block_dims.y,
-                (ne2*ne3 + block_dims.z - 1) / block_dims.z
-            );
-
-            if (block_nums.z > 65535) {
-                // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
-                    src0_dd, src1_dd, dst_dd,
-                    ne0, ne1, ne2, ne3,
-                    ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00, */ s01, s02, s03,
-                    /* s10, */ s11, s12, s13);
-            } else {
-                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
-                    src0_dd, src1_dd, dst_dd,
-                    ne0, ne1, ne2, ne3,
-                    ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00, */ s01, s02, s03,
-                    /* s10, */ s11, s12, s13);
-            }
-        }
+        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
+            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
     }
 };
 
@@ -331,6 +346,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
+            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
+            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
+            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
+            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else {
+        fprintf(stderr,
+                "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+                __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
+
+    switch (n_fuse) {
+        case 2:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
+            break;
+        case 3:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
+            break;
+        case 4:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
+            break;
+        case 5:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
+            break;
+        case 6:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
+            break;
+        case 7:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
+            break;
+        case 8:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false && "Unsupported n_fuse value");
+    }
+}
+
 void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
 
index 3ac1c9b03fcea7255a124e4a0d9e1a18b3bd8b64..62bc950111b700a13595131eb3930873a29a44eb 100644 (file)
@@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
index 4c02b57227a880d3da4c68918c3769c262960c94..6a1b0fc936092fe17635a00c468b7cd4eb7f0a9f 100644 (file)
@@ -2821,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
         const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
         const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add      = nullptr;
+
+        if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+            add = cgraph->nodes[node_idx+1];
+        }
 
         GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
         GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@@ -2835,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
+        if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+            add->src[1]->type != GGML_TYPE_F32 ||
+            add->type != GGML_TYPE_F32) ) {
+            return false;
+        }
+
         //if rms norm is the B operand, then we don't handle broadcast
         if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
             return false;
@@ -2845,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
+        if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+            return false;
+        }
+
         return true;
     }
 
@@ -2891,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
-                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+
+                    if (node->op == GGML_OP_ADD) {
+                        int n_fuse = 0;
+                        ggml_op ops[8];
+                        std::fill(ops, ops + 8, GGML_OP_ADD);
+
+                        for (; n_fuse <= 6; ++n_fuse){
+                            if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+                                break;
+                            }
+                            if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+                                break;
+                            }
+                            if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+                                break;
+                            }
+                        }
+
+                        n_fuse++;
+
+                        if (n_fuse > 1) {
+                            for (int j = 0; j < n_fuse - 1; ++j) {
+                                node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+                            }
+                            cgraph->nodes[i + n_fuse - 1]->data = node->data;
+                            ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                            i += n_fuse - 1;
+
+                            continue;
+                        }
+                    }
+
+
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+                        ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                        i += 2;
+                        continue;
+                    }
+
+                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
                         ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
                         i++;
                         continue;
index bddcca51b7bfcb43ea23b7b2780359faac27b875..293f6f68e5e52bd0c5d243ea9b2dcdbd47d03a9f 100644 (file)
@@ -104,12 +104,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 }
 
-template <int block_size, bool do_multiply = false>
-static __global__ void rms_norm_f32(
-        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
-        const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
-        const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
-        const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
+template <int block_size, bool do_multiply = false, bool do_add = false>
+static __global__ void rms_norm_f32(const float * x, float *       dst,
+                                    const int     ncols,
+                                    const int64_t stride_row,
+                                    const int64_t stride_channel,
+                                    const int64_t stride_sample,
+                                    const float   eps,
+                                    const float * mul                = nullptr,
+                                    const int64_t mul_stride_row     = 0,
+                                    const int64_t mul_stride_channel = 0,
+                                    const int64_t mul_stride_sample  = 0,
+                                    const int     mul_ncols          = 0,
+                                    const int     mul_nrows          = 0,
+                                    const int     mul_nchannels      = 0,
+                                    const int     mul_nsamples       = 0,
+                                    const float * add                = nullptr,
+                                    const int64_t add_stride_row     = 0,
+                                    const int64_t add_stride_channel = 0,
+                                    const int64_t add_stride_sample  = 0,
+                                    const int     add_ncols          = 0,
+                                    const int     add_nrows          = 0,
+                                    const int     add_nchannels      = 0,
+                                    const int     add_nsamples       = 0) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -128,6 +145,13 @@ static __global__ void rms_norm_f32(
         mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
     }
 
+    if constexpr (do_add) {
+        const int add_row     = row % add_nrows;
+        const int add_channel = channel % add_nchannels;
+        const int add_sample  = sample % add_nsamples;
+        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+    }
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -154,9 +178,16 @@ static __global__ void rms_norm_f32(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        if constexpr (do_multiply) {
+        if constexpr (do_multiply && do_add) {
+            const int mul_col = col % mul_ncols;
+            const int add_col = col % add_ncols;
+            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+        } else if constexpr (do_multiply) {
             const int mul_col = col % mul_ncols;
             dst[col] = scale * x[col] * mul[mul_col];
+        } else if constexpr (do_add) {
+            const int add_col = col % add_ncols;
+            dst[col] += add[add_col];
         } else {
             dst[col] = scale * x[col];
         }
@@ -331,23 +362,70 @@ static void rms_norm_f32_cuda(
     }
 }
 
-static void rms_norm_mul_f32_cuda(
-        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
-        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
-        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
-        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
-        const float eps, cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+                                  const float * mul,
+                                  const float * add,
+                                  float *       dst,
+                                  const int     ncols,
+                                  const int     nrows,
+                                  const int     nchannels,
+                                  const int     nsamples,
+                                  const int64_t stride_row,
+                                  const int64_t stride_channel,
+                                  const int64_t stride_sample,
+                                  const int64_t mul_stride_row,
+                                  const int64_t mul_stride_channel,
+                                  const int64_t mul_stride_sample,
+                                  const int     mul_ncols,
+                                  const int     mul_nrows,
+                                  const int     mul_nchannels,
+                                  const int     mul_nsamples,
+                                  const int64_t add_stride_row,
+                                  const int64_t add_stride_channel,
+                                  const int64_t add_stride_sample,
+                                  const int     add_ncols,
+                                  const int     add_nrows,
+                                  const int     add_nchannels,
+                                  const int     add_nsamples,
+                                  const float   eps,
+                                  cudaStream_t  stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (mul == nullptr) {
         rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
         return;
     }
-    if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+    if (add == nullptr) {
+        if (ncols < 1024) {
+            const dim3 block_dims(WARP_SIZE, 1, 1);
+            rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+        } else {
+            const dim3 block_dims(1024, 1, 1);
+            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+        }
     } else {
-        const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+        if (ncols < 1024) {
+            const dim3 block_dims(WARP_SIZE, 1, 1);
+            rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                add, add_stride_row, add_stride_channel, add_stride_sample,
+                add_ncols, add_nrows, add_nchannels, add_nsamples);
+        } else {
+            const dim3 block_dims(1024, 1, 1);
+            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                add, add_stride_row, add_stride_channel, add_stride_sample,
+                add_ncols, add_nrows, add_nchannels, add_nsamples);
+        }
     }
 }
 
@@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int mul_nchannels = mul_src->ne[2];
     const int mul_nsamples  = mul_src->ne[3];
 
-    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+                          ne00, ne01, ne02, ne03,
+                          /*s00*/ s01, s02, s03,
+                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                          /*add_s00*/ 0, 0, 0,
+                          0, 0, 0, 0,
+                          eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+                                     ggml_tensor *               dst,
+                                     ggml_tensor *               mul_tensor,
+                                     ggml_tensor *               add_tensor) {
+    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+    float               eps          = 0.0f;
+
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const float *       src0_d  = (const float *) rms_norm_src->data;
+    const float *       mul_d   = nullptr;
+    const ggml_tensor * mul_src = nullptr;
+
+    if (mul_tensor->src[0] == dst) {
+        mul_d   = (float *) mul_tensor->src[1]->data;
+        mul_src = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == dst) {
+        mul_d   = (float *) mul_tensor->src[0]->data;
+        mul_src = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const float *       add_d   = nullptr;
+    const ggml_tensor * add_src = nullptr;
+
+    if (add_tensor->src[0] == mul_tensor) {
+        add_d   = (float *) add_tensor->src[1]->data;
+        add_src = add_tensor->src[1];
+    } else if (add_tensor->src[1] == mul_tensor) {
+        add_d   = (float *) add_tensor->src[0]->data;
+        add_src = add_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    float *      dst_d  = (float *) add_tensor->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(eps >= 0.0f);
+
+    const int64_t ne00 = rms_norm_src->ne[0];
+    const int64_t ne01 = rms_norm_src->ne[1];
+    const int64_t ne02 = rms_norm_src->ne[2];
+    const int64_t ne03 = rms_norm_src->ne[3];
+
+    const size_t ts0 = ggml_type_size(rms_norm_src->type);
+    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+    const int64_t s01 = rms_norm_src->nb[1] / ts0;
+    const int64_t s02 = rms_norm_src->nb[2] / ts0;
+    const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+    const size_t ts_mul = ggml_type_size(mul_src->type);
+    GGML_ASSERT(mul_src->nb[0] == ts_mul);
+    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+    const int mul_ncols     = mul_src->ne[0];
+    const int mul_nrows     = mul_src->ne[1];
+    const int mul_nchannels = mul_src->ne[2];
+    const int mul_nsamples  = mul_src->ne[3];
+
+    const size_t ts_add = ggml_type_size(add_src->type);
+    GGML_ASSERT(add_src->nb[0] == ts_add);
+    const int64_t add_s01 = add_src->nb[1] / ts_add;
+    const int64_t add_s02 = add_src->nb[2] / ts_add;
+    const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+    const int add_ncols     = add_src->ne[0];
+    const int add_nrows     = add_src->ne[1];
+    const int add_nchannels = add_src->ne[2];
+    const int add_nsamples  = add_src->ne[3];
+
+    rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d,
+                          ne00,ne01, ne02, ne03,
+                          /*s00*/ s01, s02, s03,
+                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                          /*add_s00*/ add_s01, add_s02, add_s03,
+                          add_ncols, add_nrows, add_nchannels, add_nsamples,
+                          eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 7ea7bd4df3cc6896460b55fbc9a133172f333f78..a74f6376720ab60b07b6a036042bdedca61f4e4a 100644 (file)
@@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
 
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+                                     ggml_tensor *               dst,
+                                     ggml_tensor *               mul_tensor,
+                                     ggml_tensor *               add_tensor);
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);