]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix bug in rms_norm fusion (#15660)
authorAman Gupta <redacted>
Fri, 29 Aug 2025 13:30:06 +0000 (21:30 +0800)
committerGitHub <redacted>
Fri, 29 Aug 2025 13:30:06 +0000 (21:30 +0800)
* CUDA: fix bug in rms_norm fusion

* Fix bug for OP_REPEAT

* Fix index for add

ggml/src/ggml-cuda/binbcast.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/norm.cu

index 99a98fcbfcdb36fe364b575929c556da93eb910d..1c76566344a884493ba77fe6a0e7b0697b3bed32 100644 (file)
@@ -57,7 +57,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         const int i10 = i0 % ne10;
 
         float result = src0_row ? (float) src0_row[i0] : 0.0f;
-        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+        if constexpr (sizeof...(src1_ptrs) > 0) {
+            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+        } else {
+            result = bin_op(result, (float)src1[i_src1 + i10]);
+        }
 
         dst_row[i0] = (dst_t) result;
     }
@@ -96,7 +100,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t *   src0, const src1_t *
     const int i10 = i0 % ne10;
 
     float result = src0_row ? (float) src0_row[i0] : 0.0f;
-    result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+    if constexpr (sizeof...(src1_ptrs) > 0) {
+        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
+    } else {
+        result = bin_op(result, (float)src1[i_src1 + i10]);
+    }
 
     dst_row[i0] = (dst_t) result;
 }
@@ -231,23 +239,43 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
 
         if (block_nums.z > 65535) {
             int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
-            k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
-                <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                    ne0, ne1, ne2, ne3,
-                    ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12,s13,
-                    (const src1_t *) dst->src[I + 1]->data...);
+            if constexpr (sizeof...(I) > 0) {
+                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                        ne0, ne1, ne2, ne3,
+                        ne10, ne11, ne12, ne13,
+                        /* s0, */ s1, s2, s3,
+                        /* s00,*/ s01, s02, s03,
+                        /* s10,*/ s11, s12,s13,
+                        (const src1_t *) dst->src[I + 1]->data...);
+            } else {
+                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                        ne0, ne1, ne2, ne3,
+                        ne10, ne11, ne12, ne13,
+                        /* s0, */ s1, s2, s3,
+                        /* s00,*/ s01, s02, s03,
+                        /* s10,*/ s11, s12,s13);
+            }
         } else {
-            k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
-                <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
-                    ne0, ne1, ne2, ne3,
-                    ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10,*/ s11, s12,s13,
-                    (const src1_t *) dst->src[I + 1]->data...);
+            if constexpr (sizeof...(I) > 0) {
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                        ne0, ne1, ne2, ne3,
+                        ne10, ne11, ne12, ne13,
+                        /* s0, */ s1, s2, s3,
+                        /* s00,*/ s01, s02, s03,
+                        /* s10,*/ s11, s12,s13,
+                        (const src1_t *) dst->src[I + 1]->data...);
+            } else {
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                        ne0, ne1, ne2, ne3,
+                        ne10, ne11, ne12, ne13,
+                        /* s0, */ s1, s2, s3,
+                        /* s00,*/ s01, s02, s03,
+                        /* s10,*/ s11, s12,s13);
+            }
         }
     }
 }
@@ -327,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
 }
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
 }
 
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
index 6a1b0fc936092fe17635a00c468b7cd4eb7f0a9f..e06f95f0819ed71768a39d3c26fe2721fc10e569 100644 (file)
@@ -2827,7 +2827,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         const ggml_tensor *add      = nullptr;
 
         if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
-            add = cgraph->nodes[node_idx+1];
+            add = cgraph->nodes[node_idx+2];
         }
 
         GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
index 293f6f68e5e52bd0c5d243ea9b2dcdbd47d03a9f..d5157d958b717926497ac8fc4eb251b39e55acda 100644 (file)
@@ -127,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
                                     const int     add_nrows          = 0,
                                     const int     add_nchannels      = 0,
                                     const int     add_nsamples       = 0) {
+
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
 
@@ -135,6 +136,8 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
     const int sample    = blockIdx.z;
     const int tid       = threadIdx.x;
 
+    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
+
     x   += sample*stride_sample + channel*stride_channel + row*stride_row;
     dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
@@ -185,9 +188,6 @@ static __global__ void rms_norm_f32(const float * x, float *       dst,
         } else if constexpr (do_multiply) {
             const int mul_col = col % mul_ncols;
             dst[col] = scale * x[col] * mul[mul_col];
-        } else if constexpr (do_add) {
-            const int add_col = col % add_ncols;
-            dst[col] += add[add_col];
         } else {
             dst[col] = scale * x[col];
         }