]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: add fused rms norm (llama/14800)
authorAman Gupta <redacted>
Wed, 23 Jul 2025 01:25:42 +0000 (09:25 +0800)
committerGeorgi Gerganov <redacted>
Thu, 24 Jul 2025 17:57:40 +0000 (20:57 +0300)
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/norm.cu
src/ggml-cuda/norm.cuh
tests/test-backend-ops.cpp

index 548bc31ce215802791ed65fb9339b1f498160e86..03c380897cd8a9e29e3a44d2343add84a4abad7e 100644 (file)
@@ -55,6 +55,7 @@
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
+#include <initializer_list>
 #include <limits>
 #include <map>
 #include <memory>
@@ -2765,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        //rms norm only supports F32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        //if rms norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        //rms_norm kernel assumes contigous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
     bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
     // flag used to determine whether it is an integrated_gpu
@@ -2774,6 +2808,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
         // With the use of CUDA graphs, the execution will be performed by the graph launch.
         if (!use_cuda_graph || cuda_graph_update_required) {
+
             for (int i = 0; i < cgraph->n_nodes; i++) {
                 ggml_tensor * node = cgraph->nodes[i];
 
@@ -2781,6 +2816,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     continue;
                 }
 
+                static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+                if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+                    ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                    i++;
+                    continue;
+                }
 #ifndef NDEBUG
                 assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
                 for (int j = 0; j < GGML_MAX_SRC; j++) {
index 0020dbcec5fb59dae4d157b1771935ac65b5c868..bddcca51b7bfcb43ea23b7b2780359faac27b875 100644 (file)
@@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 }
 
-template <int block_size>
+template <int block_size, bool do_multiply = false>
 static __global__ void rms_norm_f32(
         const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
-        const int64_t stride_sample, const float eps) {
+        const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
+        const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
+        const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -119,6 +121,13 @@ static __global__ void rms_norm_f32(
     x   += sample*stride_sample + channel*stride_channel + row*stride_row;
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
+    if constexpr (do_multiply) {
+        const int mul_row = row % mul_nrows;
+        const int mul_channel = channel % mul_nchannels;
+        const int mul_sample = sample % mul_nsamples;
+        mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
+    }
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -145,7 +154,12 @@ static __global__ void rms_norm_f32(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[col] = scale * x[col];
+        if constexpr (do_multiply) {
+            const int mul_col = col % mul_ncols;
+            dst[col] = scale * x[col] * mul[mul_col];
+        } else {
+            dst[col] = scale * x[col];
+        }
     }
 }
 
@@ -310,10 +324,30 @@ static void rms_norm_f32_cuda(
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
+static void rms_norm_mul_f32_cuda(
+        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
+        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
+        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
+        const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (mul == nullptr) {
+        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+        return;
+    }
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+        rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
     }
 }
 
@@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
+    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+    float eps = 0.0f;
+
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const float * src0_d = (const float *) rms_norm_src->data;
+    const float * mul_d = nullptr;
+    const ggml_tensor * mul_src = nullptr;
+
+    if (mul_tensor->src[0] == dst) {
+        mul_d = (float *) mul_tensor->src[1]->data;
+        mul_src = mul_tensor->src[1];
+    } else if(mul_tensor->src[1] == dst) {
+        mul_d = (float *) mul_tensor->src[0]->data;
+        mul_src = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    float * dst_d = (float *) mul_tensor->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(eps >= 0.0f);
+
+    const int64_t ne00 = rms_norm_src->ne[0];
+    const int64_t ne01 = rms_norm_src->ne[1];
+    const int64_t ne02 = rms_norm_src->ne[2];
+    const int64_t ne03 = rms_norm_src->ne[3];
+
+    const size_t ts0 = ggml_type_size(rms_norm_src->type);
+    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+    const int64_t s01 = rms_norm_src->nb[1] / ts0;
+    const int64_t s02 = rms_norm_src->nb[2] / ts0;
+    const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+    const size_t ts_mul = ggml_type_size(mul_src->type);
+    GGML_ASSERT(mul_src->nb[0] == ts_mul);
+    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+    const int mul_ncols     = mul_src->ne[0];
+    const int mul_nrows     = mul_src->ne[1];
+    const int mul_nchannels = mul_src->ne[2];
+    const int mul_nsamples  = mul_src->ne[3];
+
+    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+}
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * grad  = dst->src[0]; // gradients
     const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
index 706a5660a680cb57de910dbb78e5d527d5fea09a..7ea7bd4df3cc6896460b55fbc9a133172f333f78 100644 (file)
@@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index a6d00542dd21efd19c1cd305e4313eb8c540b1bc..4898094c918e16446b9eaf324370d74ddca477fc 100644 (file)
@@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const float eps;
+    const bool broadcast;
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
@@ -2650,18 +2651,21 @@ struct test_rms_norm_mul_add : public test_case {
     bool run_whole_graph() override { return true; }
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, eps, broadcast);
     }
 
     test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
-            float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+            float eps = 1e-6f, bool broadcast = false)
+        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
+
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data());
+
         ggml_set_param(a);
         ggml_set_name(a, "a");
         ggml_set_param(b);
@@ -5354,6 +5358,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
     }
 
     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));