]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: add mean operation (llama/14313)
authorAman Gupta <redacted>
Sun, 22 Jun 2025 04:39:54 +0000 (12:39 +0800)
committerGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 08:52:14 +0000 (11:52 +0300)
* CUDA: add mean operation

* add back sum_rows_f32_cuda

* Review: early exit if col!=0

src/ggml-cuda/common.cuh
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mean.cu [new file with mode: 0644]
src/ggml-cuda/mean.cuh [new file with mode: 0644]
src/ggml-cuda/sumrows.cu
src/ggml-cuda/sumrows.cuh
tests/test-backend-ops.cpp

index 364efcaeccc0796980417beae7b56b348b7f222e..2f2fce0677066831933cea10f396f5cbe592df75 100644 (file)
@@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template<bool norm>
+static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x;
+    const int col = threadIdx.x;
+
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += blockDim.x) {
+        sum += x[row * ncols + i];
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    if (col != 0) {
+        return;
+    }
+
+    dst[row] = norm ? sum / ncols : sum;
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
index 5bab92e347a7e1790c1a43dedf71ae7db0c84021..c6bdd4fb3021f21527910250cb346e8998b3802e 100644 (file)
@@ -37,6 +37,7 @@
 #include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
@@ -2357,6 +2358,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_MEAN:
+            ggml_cuda_op_mean(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cuda_op_ssm_conv(ctx, dst);
             break;
@@ -3260,6 +3264,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
             return true;
diff --git a/src/ggml-cuda/mean.cu b/src/ggml-cuda/mean.cu
new file mode 100644 (file)
index 0000000..4b238a3
--- /dev/null
@@ -0,0 +1,19 @@
+#include "mean.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) src0->data;
+    float *             dst_d  = (float *) dst->data;
+    cudaStream_t        stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+    reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+}
diff --git a/src/ggml-cuda/mean.cuh b/src/ggml-cuda/mean.cuh
new file mode 100644 (file)
index 0000000..2b9b104
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 38dbf1b5e1fa9d9a484f55848ac9b21343d029e1..2eee08fa073754e994960d4779e825478beb8efb 100644 (file)
@@ -1,25 +1,9 @@
 #include "sumrows.cuh"
 
-static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockIdx.x;
-    const int col = threadIdx.x;
-
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += blockDim.x) {
-        sum += x[row * ncols + i];
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    if (col == 0) {
-        dst[row] = sum;
-    }
-}
-
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const dim3 block_nums(nrows, 1, 1);
-    k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+    reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ncols = src0->ne[0];
     const int64_t nrows = ggml_nrows(src0);
 
-    sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows, 1, 1);
+
+    reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
 }
index 191db1c13167e4d6e9ac3acfff5b919577c0957b..3431c599b1b89847bfc9c01e1cf7b062e0a534d7 100644 (file)
@@ -1,5 +1,4 @@
 #include "common.cuh"
 
 void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
-
 void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index b64a78159d93d3e07babc96e6be72a2b4ad8a95e..d878c6962536a59b85ceb139b5097a23f5293c95 100644 (file)
@@ -4654,6 +4654,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1));
 
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
     return test_cases;
 }